Repository: rainmaker22/SMART Branch: main Commit: 42e658542b03 Files: 52 Total size: 205.4 KB Directory structure: gitextract_98xwsw9y/ ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── configs/ │ ├── train/ │ │ └── train_scalable.yaml │ └── validation/ │ └── validation_scalable.yaml ├── data_preprocess.py ├── environment.yml ├── pyproject.toml ├── requirements.txt ├── scripts/ │ ├── install_pyg.sh │ └── traj_clstering.py ├── smart/ │ ├── __init__.py │ ├── datamodules/ │ │ ├── __init__.py │ │ └── scalable_datamodule.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── preprocess.py │ │ └── scalable_dataset.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── attention_layer.py │ │ ├── fourier_embedding.py │ │ └── mlp_layer.py │ ├── metrics/ │ │ ├── __init__.py │ │ ├── average_meter.py │ │ ├── min_ade.py │ │ ├── min_fde.py │ │ ├── next_token_cls.py │ │ └── utils.py │ ├── model/ │ │ ├── __init__.py │ │ └── smart.py │ ├── modules/ │ │ ├── __init__.py │ │ ├── agent_decoder.py │ │ ├── map_decoder.py │ │ └── smart_decoder.py │ ├── preprocess/ │ │ ├── __init__.py │ │ └── preprocess.py │ ├── tokens/ │ │ ├── __init__.py │ │ ├── cluster_frame_5_2048.pkl │ │ └── map_traj_token5.pkl │ ├── transforms/ │ │ ├── __init__.py │ │ └── target_builder.py │ └── utils/ │ ├── __init__.py │ ├── cluster_reader.py │ ├── config.py │ ├── geometry.py │ ├── graph.py │ ├── list.py │ ├── log.py │ ├── nan_checker.py │ └── weight_init.py ├── train.py └── val.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class .github ckpt/ # assets/ # C extensions *.so # /assets /data # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # 
Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv *.jpg env/ venv/ ENV/ env.bak/ venv.bak/ *.jpg pyg_depend/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # IDEs .idea .vscode # seed project av2/ lightning_logs/ lightning_logs_/ lightning_l/ .DS_Store data/argo data/res data/waymo* fig*/ data/waymo_token data/submission data/token_seq_emb_nuplan data/token_seq_emb_waymo data/nuplan* submission.tar.gz data/feat* data/scalable data/pos_data res_metrics* gathered* ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
# SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction [Paper](https://arxiv.org/abs/2405.15677) | [Webpage](https://smart-motion.github.io/smart/)
- **Ranked 1st** on the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/)
- **Champion** of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)

## News

- **[December 31, 2024]** SMART-Planner achieved state-of-the-art performance on **nuPlan closed-loop planning**
- **[September 26, 2024]** SMART was **accepted to** NeurIPS 2024
- **[August 31, 2024]** Code released
- **[May 24, 2024]** SMART won the championship of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
- **[May 24, 2024]** SMART paper released on [arxiv](https://arxiv.org/abs/2405.15677)

## Introduction

This repository contains the official implementation of SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction. SMART is a novel autonomous driving motion generation paradigm that models vectorized map and agent trajectory data into discrete sequence tokens.

https://github.com/user-attachments/assets/74a61627-8444-4e54-bb10-d317dd2aacd9

## Requirements

To set up the environment, you can use conda to create and activate a new environment with the necessary dependencies:

```bash
conda env create -f environment.yml
conda activate SMART
pip install -r requirements.txt
```

If you encounter issues while installing PyG dependencies, execute the following script:

```setup
bash scripts/install_pyg.sh
```

Alternatively, you can configure the environment in your preferred way. Installing the latest versions of PyTorch, PyG, and PyTorch Lightning should suffice.
## Data installation

**Step 1: Download the Dataset**

Download the Waymo Open Motion Dataset (`scenario protocol` format) and organize the data as follows:

```
SMART
├── data
│   ├── waymo
│   │   ├── scenario
│   │   │   ├── training
│   │   │   ├── validation
│   │   │   ├── testing
├── model
├── tools
```

**Step 2: Install the Waymo Open Dataset API**

Follow the instructions [here](https://github.com/waymo-research/waymo-open-dataset) to install the Waymo Open Dataset API.

**Step 3: Preprocess the Dataset**

Preprocess the dataset by running:

```
python data_preprocess.py --input_dir ./data/waymo/scenario/training --output_dir ./data/waymo_processed/training
```

The first path is the raw data path, and the second is the output data path. The processed data will be saved to the `data/waymo_processed/` directory as follows:

```
SMART
├── data
│   ├── waymo_processed
│   │   ├── training
│   │   ├── validation
│   │   ├── testing
├── model
├── utils
```

## Training

To train the model, run the following command:

```train
python train.py --config ${config_path}
```

The default config path is `configs/train/train_scalable.yaml`. Ensure you have downloaded and prepared the Waymo data for training.

## Evaluation

To evaluate the model, run:

```eval
python val.py --config ${config_path} --pretrain_ckpt ${ckpt_path}
```

This will evaluate the model using the configuration and checkpoint provided.

## Pre-trained Models

To comply with the WOMD participation agreement, we will release the model parameters of a medium-sized model not trained on Waymo data. Users can fine-tune this model with Waymo data as needed.
## Results ### Waymo Open Motion Dataset Sim Agents Challenge Our model achieves the following performance on the [Waymo Open Motion Dataset Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/): | Model name | Metric Score | | :-----------: | ------------ | | SMART-tiny | 0.7591 | | SMART-large | 0.7614 | | SMART-zeroshot| 0.7210 | ### NuPlan Closed-loop Planning **SMART-Planner** achieved state-of-the-art performance among learning-based algorithms on **nuPlan closed-loop planning**. The results on val14 are shown below: ![nuPlan Closed-loop Planning](assets/result1.png) ## Citation If you find this repository useful, please consider citing our work and giving us a star: ```citation @article{wu2024smart, title={SMART: Scalable Multi-agent Real-time Simulation via Next-token Prediction}, author={Wu, Wei and Feng, Xiaoxin and Gao, Ziyan and Kan, Yuheng}, journal={arXiv preprint arXiv:2405.15677}, year={2024} } ``` ## Acknowledgements Special thanks to the [QCNET](https://github.com/ZikangZhou/QCNet) repository for providing valuable reference code that significantly influenced this work. ## License All code in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0). 
================================================ FILE: __init__.py ================================================ ================================================ FILE: configs/train/train_scalable.yaml ================================================ # Config format schema number, the yaml support to valid case source from different dataset time_info: &time_info num_historical_steps: 11 num_future_steps: 80 use_intention: True token_size: 2048 Dataset: root: train_batch_size: 1 val_batch_size: 1 test_batch_size: 1 shuffle: True num_workers: 1 pin_memory: True persistent_workers: True train_raw_dir: ["data/valid_demo"] val_raw_dir: ["data/valid_demo"] test_raw_dir: transform: WaymoTargetBuilder train_processed_dir: val_processed_dir: test_processed_dir: dataset: "scalable" <<: *time_info Trainer: strategy: ddp_find_unused_parameters_false accelerator: "gpu" devices: 1 max_epochs: 32 save_ckpt_path: num_nodes: 1 mode: ckpt_path: precision: 32 accumulate_grad_batches: 1 Model: mode: "train" predictor: "smart" dataset: "waymo" input_dim: 2 hidden_dim: 128 output_dim: 2 output_head: False num_heads: 8 <<: *time_info head_dim: 16 dropout: 0.1 num_freq_bands: 64 lr: 0.0005 warmup_steps: 0 total_steps: 32 decoder: <<: *time_info num_map_layers: 3 num_agent_layers: 6 a2a_radius: 60 pl2pl_radius: 10 pl2a_radius: 30 time_span: 30 ================================================ FILE: configs/validation/validation_scalable.yaml ================================================ # Config format schema number, the yaml support to valid case source from different dataset time_info: &time_info num_historical_steps: 11 num_future_steps: 80 token_size: 2048 Dataset: root: batch_size: 1 shuffle: True num_workers: 1 pin_memory: True persistent_workers: True train_raw_dir: val_raw_dir: ["data/valid_demo"] test_raw_dir: TargetBuilder: WaymoTargetBuilder train_processed_dir: val_processed_dir: test_processed_dir: dataset: "scalable" <<: *time_info Trainer: strategy: 
ddp_find_unused_parameters_false accelerator: "gpu" devices: 1 max_epochs: 32 save_ckpt_path: num_nodes: 1 mode: ckpt_path: precision: 32 accumulate_grad_batches: 1 Model: mode: "validation" predictor: "smart" dataset: "waymo" input_dim: 2 hidden_dim: 128 output_dim: 2 output_head: False num_heads: 8 <<: *time_info head_dim: 16 dropout: 0.1 num_freq_bands: 64 lr: 0.0005 warmup_steps: 0 total_steps: 32 decoder: <<: *time_info num_map_layers: 3 num_agent_layers: 6 a2a_radius: 60 pl2pl_radius: 10 pl2a_radius: 30 time_span: 30 ================================================ FILE: data_preprocess.py ================================================ import numpy as np import pandas as pd import os import torch import pickle from tqdm import tqdm from typing import Any, Dict, List, Optional import easydict predict_unseen_agents = False vector_repr = True root = '' split = 'train' raw_dir = os.path.join(root, split, 'raw') _raw_dir = raw_dir if os.path.isdir(_raw_dir): _raw_file_names = [name for name in os.listdir(_raw_dir)] else: _raw_file_names = [] processed_dir = os.path.join(root, split, 'processed') _processed_dir = processed_dir if os.path.isdir(_processed_dir): _processed_file_names = [name for name in os.listdir(_processed_dir) if name.endswith(('pkl', 'pickle'))] else: _processed_file_names = [] _agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background'] _polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN'] _polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN'] _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW', 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE', 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE', 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE'] _point_sides = ['LEFT', 'RIGHT', 'CENTER'] _polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT'] _polygon_is_intersections = [True, 
False, None] Lane_type_hash = { 4: "BIKE", 3: "VEHICLE", 2: "VEHICLE", 1: "BUS" } boundary_type_hash = { 5: "UNKNOWN", 6: "DASHED_WHITE", 7: "SOLID_WHITE", 8: "DOUBLE_DASH_WHITE", 9: "DASHED_YELLOW", 10: "DOUBLE_DASH_YELLOW", 11: "SOLID_YELLOW", 12: "DOUBLE_SOLID_YELLOW", 13: "DASH_SOLID_YELLOW", 14: "UNKNOWN", 15: "EDGE", 16: "EDGE" } def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]: try: return ls.index(elem) except ValueError: return None def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]: if not predict_unseen_agents: # filter out agents that are unseen during the historical time steps historical_df = df[df['timestep'] == num_historical_steps-1] agent_ids = list(historical_df['track_id'].unique()) df = df[df['track_id'].isin(agent_ids)] else: agent_ids = list(df['track_id'].unique()) num_agents = len(agent_ids) # initialization valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool) current_valid_mask = torch.zeros(num_agents, dtype=torch.bool) predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool) agent_id: List[Optional[str]] = [None] * num_agents agent_type = torch.zeros(num_agents, dtype=torch.uint8) agent_category = torch.zeros(num_agents, dtype=torch.uint8) position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) heading = torch.zeros(num_agents, num_steps, dtype=torch.float) velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float) for track_id, track_df in df.groupby('track_id'): agent_idx = agent_ids.index(track_id) agent_steps = track_df['timestep'].values valid_mask[agent_idx, agent_steps] = True current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1] predict_mask[agent_idx, agent_steps] = True if vector_repr: # a time step t is valid only when both t and t-1 are valid valid_mask[agent_idx, 1: num_historical_steps] = ( valid_mask[agent_idx, 
:num_historical_steps - 1] & valid_mask[agent_idx, 1: num_historical_steps]) valid_mask[agent_idx, 0] = False predict_mask[agent_idx, :num_historical_steps] = False if not current_valid_mask[agent_idx]: predict_mask[agent_idx, num_historical_steps:] = False agent_id[agent_idx] = track_id agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0]) agent_category[agent_idx] = track_df['object_category'].values[0] position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values, track_df['position_y'].values, track_df['position_z'].values], axis=-1)).float() heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float() velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values, track_df['velocity_y'].values], axis=-1)).float() shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values, track_df['width'].values, track_df["height"].values], axis=-1)).float() av_idx = agent_id.index(av_id) if split == 'test': predict_mask[current_valid_mask | (agent_category == 2) | (agent_category == 3), num_historical_steps:] = True return { 'num_nodes': num_agents, 'av_index': av_idx, 'valid_mask': valid_mask, 'predict_mask': predict_mask, 'id': agent_id, 'type': agent_type, 'category': agent_category, 'position': position, 'heading': heading, 'velocity': velocity, 'shape': shape } def get_map_features(map_infos, tf_current_light, dim=3): lane_segments = map_infos['lane'] all_polylines = map_infos["all_polylines"] crosswalks = map_infos['crosswalk'] road_edges = map_infos['road_edge'] road_lines = map_infos['road_line'] lane_segment_ids = [info["id"] for info in lane_segments] cross_walk_ids = [info["id"] for info in crosswalks] road_edge_ids = [info["id"] for info in road_edges] road_line_ids = [info["id"] for info in road_lines] polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids num_polygons = len(lane_segment_ids) 
+ len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids) # initialization polygon_type = torch.zeros(num_polygons, dtype=torch.uint8) polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3 point_position: List[Optional[torch.Tensor]] = [None] * num_polygons point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons point_height: List[Optional[torch.Tensor]] = [None] * num_polygons point_type: List[Optional[torch.Tensor]] = [None] * num_polygons for lane_segment in lane_segments: lane_segment = easydict.EasyDict(lane_segment) lane_segment_idx = polygon_ids.index(lane_segment.id) polyline_index = lane_segment.polyline_index centerline = all_polylines[polyline_index[0]:polyline_index[1], :] centerline = torch.from_numpy(centerline).float() polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type]) res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)] if len(res) != 0: polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item()) point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) center_vectors = centerline[1:] - centerline[:-1] point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) center_type = _point_types.index('CENTERLINE') point_type[lane_segment_idx] = torch.cat( [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) for lane_segment in road_edges: lane_segment = easydict.EasyDict(lane_segment) lane_segment_idx = polygon_ids.index(lane_segment.id) polyline_index = lane_segment.polyline_index centerline = all_polylines[polyline_index[0]:polyline_index[1], :] centerline = torch.from_numpy(centerline).float() 
polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE") point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) center_vectors = centerline[1:] - centerline[:-1] point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) center_type = _point_types.index('EDGE') point_type[lane_segment_idx] = torch.cat( [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) for lane_segment in road_lines: lane_segment = easydict.EasyDict(lane_segment) lane_segment_idx = polygon_ids.index(lane_segment.id) polyline_index = lane_segment.polyline_index centerline = all_polylines[polyline_index[0]:polyline_index[1], :] centerline = torch.from_numpy(centerline).float() polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE") point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0) center_vectors = centerline[1:] - centerline[:-1] point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) center_type = _point_types.index(boundary_type_hash[lane_segment.type]) point_type[lane_segment_idx] = torch.cat( [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) for crosswalk in crosswalks: crosswalk = easydict.EasyDict(crosswalk) lane_segment_idx = polygon_ids.index(crosswalk.id) polyline_index = crosswalk.polyline_index centerline = all_polylines[polyline_index[0]:polyline_index[1], :] centerline = torch.from_numpy(centerline).float() polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN") point_position[lane_segment_idx] = 
torch.cat([centerline[:-1, :dim]], dim=0) center_vectors = centerline[1:] - centerline[:-1] point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0) point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1) point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0) center_type = _point_types.index("CROSSWALK") point_type[lane_segment_idx] = torch.cat( [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0) num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long) point_to_polygon_edge_index = torch.stack( [torch.arange(num_points.sum(), dtype=torch.long), torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0) polygon_to_polygon_edge_index = [] polygon_to_polygon_type = [] for lane_segment in lane_segments: lane_segment = easydict.EasyDict(lane_segment) lane_segment_idx = polygon_ids.index(lane_segment.id) pred_inds = [] for pred in lane_segment.entry_lanes: pred_idx = safe_list_index(polygon_ids, pred) if pred_idx is not None: pred_inds.append(pred_idx) if len(pred_inds) != 0: polygon_to_polygon_edge_index.append( torch.stack([torch.tensor(pred_inds, dtype=torch.long), torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0)) polygon_to_polygon_type.append( torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8)) succ_inds = [] for succ in lane_segment.exit_lanes: succ_idx = safe_list_index(polygon_ids, succ) if succ_idx is not None: succ_inds.append(succ_idx) if len(succ_inds) != 0: polygon_to_polygon_edge_index.append( torch.stack([torch.tensor(succ_inds, dtype=torch.long), torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0)) polygon_to_polygon_type.append( torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8)) if len(lane_segment.left_neighbors) != 0: 
left_neighbor_ids = lane_segment.left_neighbors for left_neighbor_id in left_neighbor_ids: left_idx = safe_list_index(polygon_ids, left_neighbor_id) if left_idx is not None: polygon_to_polygon_edge_index.append( torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long)) polygon_to_polygon_type.append( torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8)) if len(lane_segment.right_neighbors) != 0: right_neighbor_ids = lane_segment.right_neighbors for right_neighbor_id in right_neighbor_ids: right_idx = safe_list_index(polygon_ids, right_neighbor_id) if right_idx is not None: polygon_to_polygon_edge_index.append( torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long)) polygon_to_polygon_type.append( torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8)) if len(polygon_to_polygon_edge_index) != 0: polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1) polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0) else: polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long) polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8) map_data = { 'map_polygon': {}, 'map_point': {}, ('map_point', 'to', 'map_polygon'): {}, ('map_polygon', 'to', 'map_polygon'): {}, } map_data['map_polygon']['num_nodes'] = num_polygons map_data['map_polygon']['type'] = polygon_type map_data['map_polygon']['light_type'] = polygon_light_type if len(num_points) == 0: map_data['map_point']['num_nodes'] = 0 map_data['map_point']['position'] = torch.tensor([], dtype=torch.float) map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float) map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float) if dim == 3: map_data['map_point']['height'] = torch.tensor([], dtype=torch.float) map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8) map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8) else: map_data['map_point']['num_nodes'] = 
def process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id,
                  start_timestamp, end_timestamp):
    """Convert Waymo track arrays into an Argoverse-style agent DataFrame.

    Args:
        track_info: dict with "trajs" (num_agents, num_steps, 10), "object_id",
            and "object_type" lists.
        tracks_to_predict: dict with "track_index" of the agents to predict.
        sdc_track_index: index of the self-driving car track (unused here).
        scenario_id: scenario identifier written into every row.
        start_timestamp / end_timestamp: scenario time bounds.

    Returns:
        pandas.DataFrame with one row per valid (agent, timestep) sample.
    """
    # (num_agents, num_steps, 10) -> (num_steps, num_agents, 10)
    agents_array = track_info["trajs"].transpose(1, 0, 2)
    object_id = np.array(track_info["object_id"])
    object_type = track_info["object_type"]
    id_hash = dict(zip(object_id, object_type))

    def type_hash(x):
        # Map Waymo type strings to the lowercase vocabulary used downstream.
        type_re_hash = {
            "TYPE_VEHICLE": "vehicle",
            "TYPE_PEDESTRIAN": "pedestrian",
            "TYPE_CYCLIST": "cyclist",
            "TYPE_OTHER": "background",
            "TYPE_UNSET": "background"
        }
        return type_re_hash[id_hash[x]]

    columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',
               'position_x', 'position_y', 'position_z', 'length', 'width', 'height',
               'heading', 'velocity_x', 'velocity_y', 'scenario_id', 'start_timestamp',
               'end_timestamp', 'num_timestamps', 'focal_track_id', 'city']

    num_steps, num_agents = agents_array.shape[0], agents_array.shape[1]
    # Eleven metadata channels prepended to the raw 10 trajectory channels.
    meta = np.ones((num_steps, num_agents, 11))
    meta[:11, :, 0] = True          # first 11 steps are the observed history
    meta[11:, :, 0] = False
    meta[..., 4] = np.arange(num_steps)[:, None]   # timestep index
    meta[..., 1] = object_id        # track_id
    meta[..., 2] = object_id        # re-hashed to a type string below
    meta[:, tracks_to_predict["track_index"], 3] = 3   # category 3: to predict
    meta[..., 5] = 11
    meta[..., 6] = int(start_timestamp)
    meta[..., 7] = int(end_timestamp)
    meta[..., 8] = int(91)
    meta[..., 9] = object_id        # focal_track_id
    meta[..., 10] = 10086           # dummy city code

    merged = np.concatenate([meta, agents_array], axis=-1)
    # Keep only rows whose Waymo validity flag (last channel) is set.
    valid_rows = merged[merged[..., -1] == 1.0].reshape(-1, merged.shape[-1])
    # Reorder to match `columns` (drops the validity flag, channel 20).
    reordered = valid_rows[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                                 5, 6, 7, 8, 9, 10]]
    frame = pd.DataFrame(data=reordered, columns=columns)
    frame["object_type"] = frame["object_type"].apply(func=type_hash)
    for col in ("start_timestamp", "end_timestamp", "num_timestamps"):
        frame[col] = frame[col].astype(int)
    frame["scenario_id"] = scenario_id
    return frame


def process_dynamic_map(dynamic_map_infos):
    """Flatten per-step traffic-light observations into a long DataFrame.

    Returns a DataFrame with string columns lane_id / time_step / state, where
    raw Waymo states are collapsed to LANE_STATE_{STOP,GO,CAUTION}.
    """
    lane_ids = dynamic_map_infos["lane_id"]
    frames = []
    for step, lane_id in enumerate(lane_ids):
        step_arr = np.ones_like(lane_id) * step
        state = dynamic_map_infos["state"][step]
        frames.append(np.concatenate([lane_id, step_arr, state], axis=0))
    stacked = np.concatenate(frames, axis=1).transpose(1, 0)
    tf_lights = pd.DataFrame(data=stacked, columns=["lane_id", "time_step", "state"])
    for col in ("time_step", "lane_id", "state"):
        tf_lights[col] = tf_lights[col].astype("str")
    # Substring matches fold ARROW_/FLASHING_ variants into the base state.
    tf_lights.loc[tf_lights["state"].str.contains("STOP"), ["state"]] = 'LANE_STATE_STOP'
    tf_lights.loc[tf_lights["state"].str.contains("GO"), ["state"]] = 'LANE_STATE_GO'
    tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), ["state"]] = 'LANE_STATE_CAUTION'
    return tf_lights


# Global polyline-type vocabulary shared by the map decoders below.
polyline_type = {
    # for lane
    'TYPE_UNDEFINED': -1,
    'TYPE_FREEWAY': 1,
    'TYPE_SURFACE_STREET': 2,
    'TYPE_BIKE_LANE': 3,
    # for roadline
    'TYPE_UNKNOWN': -1,
    'TYPE_BROKEN_SINGLE_WHITE': 6,
    'TYPE_SOLID_SINGLE_WHITE': 7,
    'TYPE_SOLID_DOUBLE_WHITE': 8,
    'TYPE_BROKEN_SINGLE_YELLOW': 9,
    'TYPE_BROKEN_DOUBLE_YELLOW': 10,
    'TYPE_SOLID_SINGLE_YELLOW': 11,
    'TYPE_SOLID_DOUBLE_YELLOW': 12,
    'TYPE_PASSING_DOUBLE_YELLOW': 13,
    # for roadedge
    'TYPE_ROAD_EDGE_BOUNDARY': 15,
    'TYPE_ROAD_EDGE_MEDIAN': 16,
    # for stopsign
    'TYPE_STOP_SIGN': 17,
    # for crosswalk
    'TYPE_CROSSWALK': 18,
    # for speed bump
    'TYPE_SPEED_BUMP': 19
}

object_type = {
    0: 'TYPE_UNSET',
    1: 'TYPE_VEHICLE',
    2: 'TYPE_PEDESTRIAN',
    3: 'TYPE_CYCLIST',
    4: 'TYPE_OTHER'
}

signal_state = {
    0: 'LANE_STATE_UNKNOWN',
    # States for traffic signals with arrows.
    1: 'LANE_STATE_ARROW_STOP',
    2: 'LANE_STATE_ARROW_CAUTION',
    3: 'LANE_STATE_ARROW_GO',
    # Standard round traffic signals.
    4: 'LANE_STATE_STOP',
    5: 'LANE_STATE_CAUTION',
    6: 'LANE_STATE_GO',
    # Flashing light signals.
    7: 'LANE_STATE_FLASHING_STOP',
    8: 'LANE_STATE_FLASHING_CAUTION'
}

# Reverse lookup: state name -> numeric id.
signal_state_to_id = {name: state_id for state_id, name in signal_state.items()}


def decode_tracks_from_proto(tracks):
    """Decode agent tracks from the scenario proto.

    Returns a dict with object ids, type strings, and a stacked trajectory
    array of shape (num_objects, num_timestamps, 10):
    [cx, cy, cz, length, width, height, heading, vx, vy, valid].
    """
    track_infos = {
        'object_id': [],
        # {0: unset, 1: vehicle, 2: pedestrian, 3: cyclist, 4: others}
        'object_type': [],
        'trajs': []
    }
    for cur_data in tracks:  # one entry per object
        step_states = [np.array([x.center_x, x.center_y, x.center_z, x.length, x.width,
                                 x.height, x.heading, x.velocity_x, x.velocity_y,
                                 x.valid], dtype=np.float32)
                       for x in cur_data.states]
        track_infos['object_id'].append(cur_data.id)
        track_infos['object_type'].append(object_type[cur_data.object_type])
        track_infos['trajs'].append(np.stack(step_states, axis=0))
    track_infos['trajs'] = np.stack(track_infos['trajs'], axis=0)
    return track_infos


from collections import defaultdict
def decode_map_features_from_proto(map_features):
    """Decode static map features from the scenario proto.

    Returns a dict with per-category feature lists ('lane', 'road_line',
    'road_edge', 'stop_sign', 'crosswalk', 'speed_bump'), a flat
    'all_polylines' array of shape (num_points, 5) holding
    [x, y, z, global_type, feature_id], and lane lookup dicts.

    Fix over the original: `np.concatenate` raised on maps where every
    feature was filtered out (the author had left a commented-out
    try/except for exactly this); an explicit empty guard now returns a
    (0, 5) float32 array instead.
    """
    map_infos = {
        'lane': [],
        'road_line': [],
        'road_edge': [],
        'stop_sign': [],
        'crosswalk': [],
        'speed_bump': [],
        'lane_dict': {},
        'lane2other_dict': {}
    }
    polylines = []
    point_cnt = 0
    lane2other_dict = defaultdict(list)
    for cur_data in map_features:
        cur_info = {'id': cur_data.id}
        if cur_data.lane.ByteSize() > 0:
            cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph
            # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane
            cur_info['type'] = cur_data.lane.type + 1
            cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]
            cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]
            cur_info['interpolating'] = cur_data.lane.interpolating
            cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)
            cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)
            # +5 shifts lane-boundary enum values into the roadline range of
            # the global `polyline_type` vocabulary.
            cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
            cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
            cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
            cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]
            lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
            lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])
            global_type = cur_info['type']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id])
                 for point in cur_data.lane.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['lane'].append(cur_info)
            map_infos['lane_dict'][cur_data.id] = cur_info
        elif cur_data.road_line.ByteSize() > 0:
            cur_info['type'] = cur_data.road_line.type + 5
            global_type = cur_info['type']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id])
                 for point in cur_data.road_line.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_line'].append(cur_info)
        elif cur_data.road_edge.ByteSize() > 0:
            # +14 maps road-edge enum onto TYPE_ROAD_EDGE_* global ids.
            cur_info['type'] = cur_data.road_edge.type + 14
            global_type = cur_info['type']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id])
                 for point in cur_data.road_edge.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_edge'].append(cur_info)
        elif cur_data.stop_sign.ByteSize() > 0:
            cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
            for i in cur_info['lane_ids']:
                lane2other_dict[i].append(cur_data.id)
            point = cur_data.stop_sign.position
            cur_info['position'] = np.array([point.x, point.y, point.z])
            global_type = polyline_type['TYPE_STOP_SIGN']
            cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
            # NOTE(review): this polyline is always (1, 5), so the <=1 filter
            # below drops every stop sign from the outputs. Behavior kept
            # as-is — confirm whether stop signs were meant to survive.
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['stop_sign'].append(cur_info)
        elif cur_data.crosswalk.ByteSize() > 0:
            global_type = polyline_type['TYPE_CROSSWALK']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id])
                 for point in cur_data.crosswalk.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['crosswalk'].append(cur_info)
        elif cur_data.speed_bump.ByteSize() > 0:
            global_type = polyline_type['TYPE_SPEED_BUMP']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id])
                 for point in cur_data.speed_bump.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['speed_bump'].append(cur_info)
        else:
            continue
        polylines.append(cur_polyline)
        cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))
        point_cnt += len(cur_polyline)
    if polylines:
        polylines = np.concatenate(polylines, axis=0).astype(np.float32)
    else:
        # FIX: np.concatenate raises on an empty list; return an empty
        # (0, 5) array matching the per-point layout instead.
        polylines = np.zeros((0, 5), dtype=np.float32)
    map_infos['all_polylines'] = polylines
    map_infos['lane2other_dict'] = lane2other_dict
    return map_infos


def decode_dynamic_map_states_from_proto(dynamic_map_states):
    """Decode per-timestamp traffic-signal states from the scenario proto.

    Returns a dict of lists (one entry per timestamp) of lane ids, state
    names (via the module-level `signal_state` table), and stop points.
    """
    dynamic_map_infos = {
        'lane_id': [],
        'state': [],
        'stop_point': []
    }
    for cur_data in dynamic_map_states:  # (num_timestamp)
        lane_id, state, stop_point = [], [], []
        for cur_signal in cur_data.lane_states:  # (num_observed_signals)
            lane_id.append(cur_signal.lane)
            state.append(signal_state[cur_signal.state])
            stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, cur_signal.stop_point.z])
        dynamic_map_infos['lane_id'].append(np.array([lane_id]))
        dynamic_map_infos['state'].append(np.array([state]))
        dynamic_map_infos['stop_point'].append(np.array([stop_point]))
    return dynamic_map_infos
def process_single_data(scenario):
    """Decode one scenario proto into a plain dict of tracks, maps, and lights."""
    info = {}
    info['scenario_id'] = scenario.scenario_id
    info['timestamps_seconds'] = list(scenario.timestamps_seconds)  # list of len 91
    info['current_time_index'] = scenario.current_time_index  # int, 10
    info['sdc_track_index'] = scenario.sdc_track_index  # int
    info['objects_of_interest'] = list(scenario.objects_of_interest)  # may be empty
    # For training: suggested objects to train on; for val/test: must be predicted.
    info['tracks_to_predict'] = {
        'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
        'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
    }
    track_infos = decode_tracks_from_proto(scenario.tracks)
    info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx]
                                               for cur_idx in info['tracks_to_predict']['track_index']]
    # Decode map-related data.
    map_infos = decode_map_features_from_proto(scenario.map_features)
    dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)

    save_infos = {
        'track_infos': track_infos,
        'dynamic_map_infos': dynamic_map_infos,
        'map_infos': map_infos
    }
    save_infos.update(info)
    return save_infos


import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2


def wm2argo(file, dir_name, output_dir):
    """Convert every scenario in one Waymo TFRecord to an Argoverse-style pickle.

    Args:
        file: TFRecord file name inside `dir_name`.
        dir_name: directory holding the raw TFRecords.
        output_dir: directory receiving one `<scenario_id>.pkl` per scenario.
    """
    file_path = os.path.join(dir_name, file)
    dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)
    for cnt, record in enumerate(dataset):
        print(cnt)
        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytearray(record.numpy()))
        save_infos = process_single_data(scenario)
        # pkl2mtr
        map_info = save_infos["map_infos"]
        track_info = save_infos['track_infos']
        scenario_id = save_infos['scenario_id']
        tracks_to_predict = save_infos['tracks_to_predict']
        sdc_track_index = save_infos['sdc_track_index']
        av_id = track_info["object_id"][sdc_track_index]
        if len(tracks_to_predict["track_index"]) < 1:
            # FIX: was `return`, which silently aborted ALL remaining scenarios
            # in this TFRecord; only the scenario without targets is skipped.
            continue
        dynamic_map_infos = save_infos["dynamic_map_infos"]
        tf_lights = process_dynamic_map(dynamic_map_infos)
        # Lights at the current step (index 11) drive the map-feature states.
        tf_current_light = tf_lights.loc[tf_lights["time_step"] == "11"]
        map_data = get_map_features(map_info, tf_current_light)
        new_agents_array = process_agent(track_info, tracks_to_predict, sdc_track_index,
                                         scenario_id, 0, 91)
        # mtr2argo (renamed from `data` to avoid shadowing the loop variable)
        sample = dict()
        sample['scenario_id'] = new_agents_array['scenario_id'].values[0]
        sample['city'] = new_agents_array['city'].values[0]
        sample['agent'] = get_agent_features(new_agents_array, av_id, num_historical_steps=11)
        sample.update(map_data)
        with open(os.path.join(output_dir, scenario_id + '.pkl'), "wb+") as f:
            pickle.dump(sample, f)


def batch_process9s_transformer(dir_name, output_dir, num_workers=2):
    """Run `wm2argo` over every TFRecord in `dir_name` with a process pool."""
    from functools import partial
    import multiprocessing
    packages = os.listdir(dir_name)
    func = partial(wm2argo, output_dir=output_dir, dir_name=dir_name)
    with multiprocessing.Pool(num_workers) as p:
        list(tqdm(p.imap(func, packages), total=len(packages)))


from argparse import ArgumentParser

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='data/waymo/scenario/training')
    parser.add_argument('--output_dir', type=str, default='data/waymo_processed/training')
    args = parser.parse_args()
    # Robustness: make sure the destination exists before writing pickles.
    os.makedirs(args.output_dir, exist_ok=True)
    files = os.listdir(args.input_dir)
    for file in tqdm(files):
        wm2argo(file, args.input_dir, args.output_dir)
    # batch_process9s_transformer(args.input_dir, args.output_dir, num_workers="ur_cpu_count")
libtasn1=4.19.0=h5eee18b_0 - libtiff=4.5.1=h6a678d5_0 - libunistring=0.9.10=h27cfd23_0 - libwebp-base=1.3.2=h5eee18b_1 - lz4-c=1.9.4=h6a678d5_1 - mkl=2023.1.0=h213fc3f_46344 - mkl-service=2.4.0=py39h5eee18b_1 - mkl_fft=1.3.10=py39h5eee18b_0 - mkl_random=1.2.7=py39h1128e8f_0 - ncurses=6.4=h6a678d5_0 - nettle=3.7.3=hbbd107a_1 - openh264=2.1.1=h4ff587b_0 - openjpeg=2.5.2=he7f1fd0_0 - openssl=3.0.15=h5eee18b_0 - pillow=10.4.0=py39h5eee18b_0 - pip=24.2=py39h06a4308_0 - pysocks=1.7.1=py39h06a4308_0 - python=3.9.19=h955ad1f_1 - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 - pytorch-mutex=1.0=cuda - readline=8.2=h5eee18b_0 - requests=2.32.3=py39h06a4308_0 - setuptools=75.1.0=py39h06a4308_0 - sqlite=3.45.3=h5eee18b_0 - tbb=2021.8.0=hdb19cb5_0 - tk=8.6.14=h39e8969_0 - torchvision=0.13.1=py39_cu113 - typing_extensions=4.11.0=py39h06a4308_0 - urllib3=2.2.3=py39h06a4308_0 - wheel=0.44.0=py39h06a4308_0 - xz=5.4.6=h5eee18b_1 - zlib=1.2.13=h5eee18b_1 - zstd=1.5.6=hc292b87_0 ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=42", "wheel"] build-backend = "setuptools.build_meta" [project] name = "smart" version = "0.0.0" description = "Scalable Multi-agent Real-time Motion Generation via Next-token Prediction" readme = "README.md" authors = [ {name = "Xiaoxin Feng"}, {name = "Ziyan Gao"}, {name = "Yuheng Kan"} ] classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", ] requires-python = ">=3.9" dependencies = [ "easydict", "numpy", "pandas", "pytorch-lightning", "scipy", "torch-cluster", "torch-geometric", "torch-scatter", "torch", "torchmetrics", "tqdm", ] [project.urls] "Homepage" = "https://smart-motion.github.io/smart/" "Repository" = "https://github.com/rainmaker22/SMART" "Paper" = "https://arxiv.org/abs/2405.15677" [tool.setuptools] packages = ["smart"] 
================================================ FILE: requirements.txt ================================================ aiohappyeyeballs==2.4.3 aiohttp==3.10.10 aiosignal==1.3.1 async-timeout==4.0.3 attrs==24.2.0 contourpy==1.3.0 cycler==0.12.1 easydict==1.13 fonttools==4.54.1 frozenlist==1.4.1 fsspec==2024.10.0 importlib-resources==6.4.5 jinja2==3.1.4 kiwisolver==1.4.7 lightning-utilities==0.11.8 markupsafe==3.0.2 matplotlib==3.9.2 multidict==6.1.0 numpy==1.26.4 packaging==24.1 pandas==2.0.3 propcache==0.2.0 psutil==6.1.0 pyparsing==3.2.0 python-dateutil==2.9.0.post0 pytorch-lightning==2.0.3 pytz==2024.2 pyyaml==6.0.1 scipy==1.10.1 shapely==2.0.6 six==1.16.0 torch-cluster==1.6.0+pt112cu113 torch-geometric==2.6.1 torch-scatter==2.1.0+pt112cu113 torch-sparse==0.6.16+pt112cu113 torch-spline-conv==1.2.1+pt112cu113 torchmetrics==1.5.0 tqdm==4.66.5 tzdata==2024.2 yarl==1.16.0 zipp==3.20.2 waymo-open-dataset-tf-2-12-0==1.6.4 ================================================ FILE: scripts/install_pyg.sh ================================================ mkdir pyg_depend && cd pyg_depend wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl python3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl python3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl python3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl python3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl python3 -m pip install torch_geometric ================================================ FILE: scripts/traj_clstering.py 
from smart.utils.geometry import wrap_angle
import numpy as np


def average_distance_vectorized(point_set1, centroids):
    """Mean per-vertex Euclidean distance of each contour to each centroid."""
    diffs = point_set1[:, None, :, :] - centroids[None, :, :, :]
    dists = np.sqrt(np.sum(diffs ** 2, axis=-1))
    return np.mean(dists, axis=2)


def assign_clusters(sub_X, centroids):
    """Index of the nearest centroid (by mean vertex distance) per contour."""
    return np.argmin(average_distance_vectorized(sub_X, centroids), axis=1)


def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
    """Greedy K-disk clustering of agent contours.

    Repeatedly picks a random contour, removes every contour within `tol`
    of it, and records the disk representative until N centroids exist.
    Relies on the module-level `cal_mean_heading` flag set in __main__.
    """
    centroid_list = []
    traj_list = []
    while len(centroid_list) < N:
        num_all = X.shape[0]
        # Randomly choose the seed contour for this disk.
        pick = np.random.choice(num_all)
        x0 = X[pick]
        # Reject seeds outside the region of interest.
        if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10:
            continue
        sq_dist = np.sum((X - x0) ** 2, axis=(1, 2)) / 4
        res_mask = sq_dist > (tol ** 2)    # contours that survive this round
        del_mask = sq_dist <= (tol ** 2)   # contours absorbed by this disk
        if cal_mean_heading:
            # Replace the seed with the disk's mean box (position + heading).
            del_contour = X[del_mask]
            diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :]
            del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean()
            x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length)
            del_traj = a_pos[del_mask]
            ret_traj = del_traj.mean(0)[None, ...]
            if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0:
                print(ret_traj)
                print('1')
        else:
            x0 = x0[None, ...]
            ret_traj = a_pos[pick][None, ...]
        X = X[res_mask]
        a_pos = a_pos[res_mask]
        centroid_list.append(x0)
        traj_list.append(ret_traj)
    centroids = np.concatenate(centroid_list, axis=0)
    ret_traj = np.concatenate(traj_list, axis=0)
    return centroids, ret_traj


def cal_polygon_contour(x, y, theta, width, length):
    """Four bounding-box corners (LF, RF, RB, LB) for boxes at (x, y, theta)."""
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    half_w, half_l = 0.5 * width, 0.5 * length
    left_front = np.column_stack((x + half_l * cos_t - half_w * sin_t,
                                  y + half_l * sin_t + half_w * cos_t))
    right_front = np.column_stack((x + half_l * cos_t + half_w * sin_t,
                                   y + half_l * sin_t - half_w * cos_t))
    right_back = np.column_stack((x - half_l * cos_t + half_w * sin_t,
                                  y - half_l * sin_t - half_w * cos_t))
    left_back = np.column_stack((x - half_l * cos_t - half_w * sin_t,
                                 y - half_l * sin_t + half_w * cos_t))
    return np.concatenate((left_front[:, None, :], right_front[:, None, :],
                           right_back[:, None, :], left_back[:, None, :]), axis=1)


if __name__ == '__main__':
    shift = 5        # motion token time dimension
    num_cluster = 6  # vocabulary size
    cal_mean_heading = True
    # Trajectories of all traffic participants from the raw data:
    # [NumAgent, shift+1, [relative_x, relative_y, relative_theta]]
    data = {
        "veh": np.random.rand(1000, 6, 3),
        "cyc": np.random.rand(1000, 6, 3),
        "ped": np.random.rand(1000, 6, 3)
    }
    nms_res = {}
    res = {'token': {}, 'traj': {}, 'token_all': {}}
    for k, v in data.items():
        a_pos = v
        print(a_pos.shape)
        cal_num = min(int(1e6), a_pos.shape[0])
        a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)]
        a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1])
        print(a_pos.shape)
        # Class-dependent nominal box sizes (coarser for short shifts).
        if shift <= 2:
            if k == 'veh':
                width, length = 1.0, 2.4
            elif k == 'cyc':
                width, length = 0.5, 1.5
            else:
                width, length = 0.5, 0.5
        else:
            if k == 'veh':
                width, length = 2.0, 4.8
            elif k == 'cyc':
                width, length = 1.0, 2.0
            else:
                width, length = 1.0, 1.0
        contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1],
                                      a_pos[:, shift, 2], width, length)
        tol = 0.05 if k == 'veh' else 0.004
        centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length,
                                            a_pos[:, :shift + 1])
        flat = num_cluster * (shift + 1)
        contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(flat),
                                      ret_traj[:, :, 1].reshape(flat),
                                      ret_traj[:, :, 2].reshape(flat),
                                      width, length)
        res['token_all'][k] = contour.reshape(num_cluster, (shift + 1), 4, 2)
        res['token'][k] = centroids
        res['traj'][k] = ret_traj
from typing import Optional

import pytorch_lightning as pl
from torch_geometric.loader import DataLoader

from smart.datasets.scalable_dataset import MultiDataset
from smart.transforms import WaymoTargetBuilder


class MultiDataModule(pl.LightningDataModule):
    """Lightning data module wiring train/val/test splits of the scalable dataset."""

    # Registries so the YAML config can select classes by name.
    transforms = {
        "WaymoTargetBuilder": WaymoTargetBuilder,
    }
    dataset = {
        "scalable": MultiDataset,
    }

    def __init__(self,
                 root: str,
                 train_batch_size: int,
                 val_batch_size: int,
                 test_batch_size: int,
                 shuffle: bool = False,
                 num_workers: int = 0,
                 pin_memory: bool = True,
                 persistent_workers: bool = True,
                 train_raw_dir: Optional[str] = None,
                 val_raw_dir: Optional[str] = None,
                 test_raw_dir: Optional[str] = None,
                 train_processed_dir: Optional[str] = None,
                 val_processed_dir: Optional[str] = None,
                 test_processed_dir: Optional[str] = None,
                 transform: Optional[str] = None,
                 dataset: Optional[str] = None,
                 num_historical_steps: int = 50,
                 num_future_steps: int = 60,
                 processor='ntp',
                 use_intention=False,
                 token_size=512,
                 **kwargs) -> None:
        super(MultiDataModule, self).__init__()
        self.root = root
        self.dataset_class = dataset
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        # Persistent workers only make sense with at least one worker.
        self.persistent_workers = persistent_workers and num_workers > 0
        self.train_raw_dir = train_raw_dir
        self.val_raw_dir = val_raw_dir
        self.test_raw_dir = test_raw_dir
        self.train_processed_dir = train_processed_dir
        self.val_processed_dir = val_processed_dir
        self.test_processed_dir = test_processed_dir
        self.processor = processor
        self.use_intention = use_intention
        self.token_size = token_size
        # One target-builder per split; the test split gets no split argument.
        transform_cls = MultiDataModule.transforms[transform]
        self.train_transform = transform_cls(num_historical_steps, num_future_steps, "train")
        self.val_transform = transform_cls(num_historical_steps, num_future_steps, "val")
        self.test_transform = transform_cls(num_historical_steps, num_future_steps)

    def setup(self, stage: Optional[str] = None) -> None:
        dataset_cls = MultiDataModule.dataset[self.dataset_class]
        self.train_dataset = dataset_cls(self.root, 'train',
                                         processed_dir=self.train_processed_dir,
                                         raw_dir=self.train_raw_dir,
                                         processor=self.processor,
                                         transform=self.train_transform,
                                         token_size=self.token_size)
        self.val_dataset = dataset_cls(None, 'val',
                                       processed_dir=self.val_processed_dir,
                                       raw_dir=self.val_raw_dir,
                                       processor=self.processor,
                                       transform=self.val_transform,
                                       token_size=self.token_size)
        self.test_dataset = dataset_cls(None, 'test',
                                        processed_dir=self.test_processed_dir,
                                        raw_dir=self.test_raw_dir,
                                        processor=self.processor,
                                        transform=self.test_transform,
                                        token_size=self.token_size)

    def _make_loader(self, dataset, batch_size, shuffle):
        # Shared DataLoader construction for the three splits.
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers,
                          pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, self.train_batch_size, self.shuffle)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, self.val_batch_size, False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, self.test_batch_size, False)
def cal_polygon_contour(x, y, theta, width, length):
    """Four bounding-box corners (LF, RF, RB, LB) for boxes at (x, y, theta).

    Inputs may be scalars or 1-D arrays; returns (N, 4, 2).
    """
    left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_front = np.column_stack((left_front_x, left_front_y))

    right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_front = np.column_stack((right_front_x, right_front_y))

    right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_back = np.column_stack((right_back_x, right_back_y))

    left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_back = np.column_stack((left_back_x, left_back_y))

    polygon_contour = np.concatenate(
        (left_front[:, None, :], right_front[:, None, :],
         right_back[:, None, :], left_back[:, None, :]), axis=1)
    return polygon_contour


def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
    """Upsample a polyline to `distance` spacing and cut it into fixed segments.

    The polyline is first split wherever the heading jumps or a large gap
    occurs, each piece is re-sampled at `distance` meters, and pieces are
    windowed into segments of `split_distace` meters (3 samples each).

    Fixes over the original:
      * the heading-jump test used `heading[1]` where `heading[i]` was
        clearly intended (the index typo appeared twice);
      * the three redundant elif branches (pi/4, pi/8, 0.1 thresholds, all
        gated on the same distance) collapse to the single 0.1 threshold;
      * unused `new_x_list`/`new_y_list` accumulators removed.
    (Function and parameter names keep their original spelling for API
    compatibility.)

    Returns a torch tensor of shape (num_segments, 3, 3) with [x, y, yaw]
    per sample, or None if no segment could be formed.
    """
    # --- split the source polyline into smooth, contiguous pieces ---------
    dist_along_path_list = [[0]]
    polylines_list = [[polylines[0]]]
    for i in range(1, polylines.shape[0]):
        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
        hi = max(heading[i], heading[i - 1])
        lo = min(heading[i], heading[i - 1])  # FIX: was min(heading[1], ...)
        heading_diff = min(abs(hi - lo), abs(hi - lo + math.pi))
        if (heading_diff > 0.1 and euclidean_dist > 3) or euclidean_dist > 10:
            # Start a new piece at sharp turns or large gaps.
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        else:
            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
            polylines_list[-1].append(polylines[i])

    # --- resample each piece and window it into segments -------------------
    multi_polylines_list = []
    polyline_size = int(split_distace / distance)  # samples per segment (minus one)
    for seg_dists, seg_points in zip(dist_along_path_list, polylines_list):
        if len(seg_dists) < 2:
            continue
        dist_along_path = np.array(seg_dists)
        polylines_cur = np.array(seg_points)
        # Interpolate x and y over arc length, keeping the exact endpoint.
        fx = interp1d(dist_along_path, polylines_cur[:, 0])
        fy = interp1d(dist_along_path, polylines_cur[:, 1])
        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
        new_polylines = np.vstack((fx(new_dist_along_path), fy(new_dist_along_path))).T

        n = new_polylines.shape[0]
        if n >= (polyline_size + 1):
            padding_size = (n - (polyline_size + 1)) % polyline_size
            final_index = (n - (polyline_size + 1)) // polyline_size + 1
        else:
            padding_size = n
            final_index = 0

        new_polylines = torch.from_numpy(new_polylines)
        # Per-sample yaw from consecutive differences (last yaw repeated).
        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
                                  new_polylines[1:, 0] - new_polylines[:-1, 0])
        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
        new_polylines = torch.cat([new_polylines, new_heading], -1)

        multi_polylines = None
        if n >= (polyline_size + 1):
            # Overlapping windows of polyline_size+1 samples, thinned to 3.
            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1,
                                                   step=polyline_size)
            multi_polylines = multi_polylines.transpose(1, 2)
            multi_polylines = multi_polylines[:, ::5, :]
        if padding_size >= 3:
            # Tail shorter than a full window: keep it as a 3-sample segment.
            last_polyline = new_polylines[final_index * polyline_size:]
            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1,
                                                         steps=3).long()]
            if multi_polylines is not None:
                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
            else:
                multi_polylines = last_polyline.unsqueeze(0)
        if multi_polylines is None:
            continue
        multi_polylines_list.append(multi_polylines)

    if len(multi_polylines_list) > 0:
        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
    else:
        multi_polylines_list = None
    return multi_polylines_list


def average_distance_vectorized(point_set1, centroids):
    """Mean per-vertex Euclidean distance of each contour to each centroid."""
    dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1))
    return np.mean(dists, axis=2)


def assign_clusters(sub_X, centroids):
    """Index of the nearest centroid (by mean vertex distance) per contour."""
    distances = average_distance_vectorized(sub_X, centroids)
    return np.argmin(distances, axis=1)
agent_token_data['token_all']
        self.map_token = {'traj_src': map_token_traj['traj_src'], }
        self.token_last = {}
        # Precompute, per agent class, the final-step token contours rotated into
        # the frame of the penultimate step — used later for "extra matching".
        for k, v in self.trajectory_token_all.items():
            token_last = torch.from_numpy(v[:, -2:]).to(torch.float)
            diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3]
            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(token_last.shape[0], 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = -sin
            rot_mat[:, 1, 0] = sin
            rot_mat[:, 1, 1] = cos
            agent_token = torch.bmm(token_last[:, 1], rot_mat)
            agent_token -= token_last[:, 0].mean(1)[:, None, :]
            self.token_last[k] = agent_token.numpy()

    def clean_heading(self, data):
        """Suppress single-frame heading jumps (> 1 rad between valid frames)
        by copying the previous frame's heading forward. Mutates
        data['agent']['heading'] in place.
        """
        heading = data['agent']['heading']
        valid = data['agent']['valid_mask']
        pi = torch.tensor(torch.pi)
        n_vehicles, n_frames = heading.shape

        heading_diff_raw = heading[:, :-1] - heading[:, 1:]
        # Wrap frame-to-frame differences into [-pi, pi).
        heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
        heading_diff[heading_diff > pi] -= 2 * pi
        heading_diff[heading_diff < -pi] += 2 * pi

        valid_pairs = valid[:, :-1] & valid[:, 1:]

        for i in range(n_frames - 1):
            change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1]

            heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()]

            if i < n_frames - 2:
                # Recompute the next difference since heading[:, i+1] may have changed.
                heading_diff_raw = heading[:, i + 1] - heading[:, i + 2]
                heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
                heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi
                heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi

    def tokenize_agent(self, data):
        """Convert each agent's continuous trajectory into a sequence of motion
        tokens (indices into the class vocabulary) plus derived token pose,
        heading, velocity and validity. Mutates and returns `data`.
        """
        if data['agent']["velocity"].shape[1] == 90:
            print(data['scenario_id'], data['agent']["velocity"].shape)
        # Agents invalid at current_step but with a nonzero position: back-fill
        # one frame by integrating velocity so they can still be tokenized.
        interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * (
                data['agent']['position'][:, self.current_step, 0] != 0)
        if data['agent']["velocity"].shape[-1] == 2:
            # Pad velocity with a zero z-component so it is always 3-D.
            data['agent']["velocity"] = torch.cat([data['agent']["velocity"],
                                                   torch.zeros(data['agent']["velocity"].shape[0],
                                                               data['agent']["velocity"].shape[1], 1)],
                                                  dim=-1)
        vel = data['agent']["velocity"][interplote_mask, self.current_step]
        data['agent']['position'][interplote_mask, self.current_step - 1, :3] = data['agent']['position'][
                                                                                    interplote_mask, self.current_step, :3] - vel * 0.1
        data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True
        data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][
            interplote_mask, self.current_step]
        data['agent']["velocity"][interplote_mask, self.current_step - 1] = data['agent']["velocity"][
            interplote_mask, self.current_step]

        data['agent']['type'] = data['agent']['type'].to(torch.uint8)
        self.clean_heading(data)
        # Agents that appear at current_step but were absent one shift earlier.
        matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * (
                data['agent']['valid_mask'][:, self.current_step - 5] == False)

        interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0)
        data['agent']['valid_mask'][interplote_mask_first, 0] = True

        agent_pos = data['agent']['position'][:, :, :2]
        valid_mask = data['agent']['valid_mask']

        # A token step is valid only if both its endpoints are valid.
        valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)
        token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]
        agent_type = data['agent']['type']
        agent_category = data['agent']['category']
        agent_heading = data['agent']['heading']
        vehicle_mask = agent_type == 0
        cyclist_mask = agent_type == 2
        ped_mask = agent_type == 1

        veh_pos = agent_pos[vehicle_mask, :, :]
        veh_valid_mask = valid_mask[vehicle_mask, :]
        cyc_pos = agent_pos[cyclist_mask, :, :]
        cyc_valid_mask = valid_mask[cyclist_mask, :]
        ped_pos = agent_pos[ped_mask, :, :]
        ped_valid_mask = valid_mask[ped_mask, :]

        # Tokenize each agent class against its own vocabulary.
        veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask],
                                                              'veh', agent_category[vehicle_mask],
                                                              matching_extra_mask[vehicle_mask])
        ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask],
                                                              'ped', agent_category[ped_mask],
                                                              matching_extra_mask[ped_mask])
        cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask],
                                                              'cyc', agent_category[cyclist_mask],
                                                              matching_extra_mask[cyclist_mask])

        # Scatter per-class results back into full-agent-ordered tensors.
        token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)
        token_index[vehicle_mask] = veh_token_index
        token_index[ped_mask] = ped_token_index
        token_index[cyclist_mask] = cyc_token_index

        token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],
                                     veh_token_contour.shape[2], veh_token_contour.shape[3]))
        token_contour[vehicle_mask] = veh_token_contour
        token_contour[ped_mask] = ped_token_contour
        token_contour[cyclist_mask] = cyc_token_contour

        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float)
        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float)
        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float)
        agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2))
        agent_token_traj[vehicle_mask] = trajectory_token_veh
        agent_token_traj[ped_mask] = trajectory_token_ped
        agent_token_traj[cyclist_mask] = trajectory_token_cyc

        if not self.training:
            # At eval, the second token slot of late-appearing agents was filled
            # by extra matching, so mark it valid.
            token_valid_mask[matching_extra_mask, 1] = True

        data['agent']['token_idx'] = token_index
        data['agent']['token_contour'] = token_contour
        token_pos = token_contour.mean(dim=2)
        data['agent']['token_pos'] = token_pos
        # Heading from the front-left/back-left corner difference of the contour.
        diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]
        data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
        data['agent']['agent_valid_mask'] = token_valid_mask

        # Finite-difference velocity between consecutive token positions
        # (0.1s per frame times shift frames per token).
        vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2),
                         ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1)
        vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool),
                                    (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1)
        vel[~vel_valid_mask] = 0
        vel[data['agent']['valid_mask'][:, self.current_step], 1] = data['agent']['velocity'][
            data['agent']['valid_mask'][:, self.current_step], self.current_step, :2]

        data['agent']['token_velocity'] = vel
        return data

    def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask):
        """Greedy nearest-contour matching of one agent class against its token
        vocabulary. Returns (token_index, token_contour) per token step.
        Matching is autoregressive: each step's candidate contours are placed
        in the frame of the previously matched pose.
        """
        agent_token_src = self.trajectory_token[category]
        token_last = self.token_last[category]
        # Per-class nominal box size, halved for short shifts.
        if self.shift <= 2:
            if category == 'veh':
                width = 1.0
                length = 2.4
            elif category == 'cyc':
                width = 0.5
                length = 1.5
            else:
                width = 0.5
                length = 0.5
        else:
            if category == 'veh':
                width = 2.0
                length = 4.8
            elif category == 'cyc':
                width = 1.0
                length = 2.0
            else:
                width = 1.0
                length = 1.0

        prev_heading = heading[:, 0]
        prev_pos = pos[:, 0]
        agent_num, num_step, feat_dim = pos.shape
        token_num, token_contour_dim, feat_dim = agent_token_src.shape
        agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)
        token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0)
        token_index_list = []
        token_contour_list = []
        prev_token_idx = None
        for i in range(self.shift, pos.shape[1], self.shift):
            theta = prev_heading
            cur_heading = heading[:, i]
            cur_pos = pos[:, i]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(agent_num, 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            # Place all candidate token contours into the world frame of the
            # previously matched pose.
            agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num,
                                                                                                              token_num,
                                                                                                              token_contour_dim,
                                                                                                              feat_dim)
            agent_token_world += prev_pos[:, None, None, :]

            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
            # Pick the token whose contour is closest (mean corner distance).
            agent_token_index = torch.from_numpy(np.argmin(
                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
                axis=-1))
            if prev_token_idx is not None and self.noise:
                # Optional data augmentation: re-sample among the 5 nearest tokens.
                same_idx = prev_token_idx == agent_token_index
                same_idx[:] = True
                topk_indices = np.argsort(
                    np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
                            axis=2), axis=-1)[:, :5]
                sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
                agent_token_index[same_idx] = \
                    torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]

            token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index]
            diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :]

            # Where the previous step was valid, roll the matched pose forward;
            # otherwise fall back to the ground-truth pose at step i.
            prev_heading = heading[:, i].clone()
            prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
                valid_mask[:, i - self.shift]]
            prev_pos = pos[:, i].clone()
            prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]]
            prev_token_idx = agent_token_index
            token_index_list.append(agent_token_index[:, None])
            token_contour_list.append(token_contour_select[:, None, ...])

        token_index = torch.cat(token_index_list, dim=1)
        token_contour = torch.cat(token_contour_list, dim=1)

        # extra matching
        if not self.training:
            # Fill token slot 1 for agents that appeared late, matching against
            # the precomputed last-step vocabulary (self.token_last).
            theta = heading[extra_mask, self.current_step - 1]
            prev_pos = pos[extra_mask, self.current_step - 1]
            cur_pos = pos[extra_mask, self.current_step]
            cur_heading = heading[extra_mask, self.current_step]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(
                extra_mask.sum(), token_num, token_contour_dim, feat_dim)
            agent_token_world += prev_pos[:, None, None, :]

            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
            agent_token_index = torch.from_numpy(np.argmin(
                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
                axis=-1))
            token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]

            token_index[extra_mask, 1] = agent_token_index
            token_contour[extra_mask, 1] = token_contour_select

        return token_index, token_contour

    def tokenize_map(self, data):
        """Split map polylines per polygon/side/type, resample them with
        interplating_polyline, and store the resulting map tokens on `data`.
        """
        data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
        data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
        pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
        pt_type = data['map_point']['type'].to(torch.uint8)
        pt_side = torch.zeros_like(pt_type)
        pt_pos = data['map_point']['position'][:, :2]
        data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
        pt_heading = data['map_point']['orientation']
        split_polyline_type = []
        split_polyline_pos = []
        split_polyline_theta = []
        split_polyline_side = []
        pl_idx_list = []
        split_polygon_type = []
        data['map_point']['type'].unique()

        for i in sorted(np.unique(pt2pl[1])):
            index = pt2pl[0, pt2pl[1] == i]
            polygon_type = data['map_polygon']["type"][i]
            cur_side = pt_side[index]
            cur_type = pt_type[index]
            cur_pos = pt_pos[index]
            cur_heading = pt_heading[index]

            for side_val in np.unique(cur_side):
                for type_val in np.unique(cur_type):
                    if type_val == 13:
                        # NOTE(review): type 13 is skipped — presumably a
                        # non-polyline point type; confirm against the dataset spec.
                        continue
                    indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
                    if len(indices) <= 2:
                        continue
                    split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
                    if split_polyline is None:
                        continue
                    new_cur_type = cur_type[indices][0]
                    new_cur_side = cur_side[indices][0]
                    map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
                    new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
                    new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
                    cur_pl_idx = torch.Tensor([i])
                    new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
                    split_polyline_pos.append(split_polyline[..., :2])
split_polyline_theta.append(split_polyline[..., 2]) split_polyline_type.append(new_cur_type) split_polyline_side.append(new_cur_side) pl_idx_list.append(new_cur_pl_idx) split_polygon_type.append(map_polygon_type) split_polyline_pos = torch.cat(split_polyline_pos, dim=0) split_polyline_theta = torch.cat(split_polyline_theta, dim=0) split_polyline_type = torch.cat(split_polyline_type, dim=0) split_polyline_side = torch.cat(split_polyline_side, dim=0) split_polygon_type = torch.cat(split_polygon_type, dim=0) pl_idx_list = torch.cat(pl_idx_list, dim=0) vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :] data['map_save'] = {} data['pt_token'] = {} data['map_save']['traj_pos'] = split_polyline_pos data['map_save']['traj_theta'] = split_polyline_theta[:, 0] # torch.arctan2(vec[:, 1], vec[:, 0]) data['map_save']['pl_idx_list'] = pl_idx_list data['pt_token']['type'] = split_polyline_type data['pt_token']['side'] = split_polyline_side data['pt_token']['pl_type'] = split_polygon_type data['pt_token']['num_nodes'] = split_polyline_pos.shape[0] return data ================================================ FILE: smart/datasets/scalable_dataset.py ================================================ import os import pickle from typing import Callable, List, Optional, Tuple, Union import pandas as pd from torch_geometric.data import Dataset from smart.utils.log import Logging import numpy as np from .preprocess import TokenProcessor def distance(point1, point2): return np.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2) class MultiDataset(Dataset): def __init__(self, root: str, split: str, raw_dir: List[str] = None, processed_dir: List[str] = None, transform: Optional[Callable] = None, dim: int = 3, num_historical_steps: int = 50, num_future_steps: int = 60, predict_unseen_agents: bool = False, vector_repr: bool = True, cluster: bool = False, processor=None, use_intention=False, token_size=512) -> None: self.logger = Logging().log(level='DEBUG') self.root = 
root self.well_done = [0] if split not in ('train', 'val', 'test'): raise ValueError(f'{split} is not a valid split') self.split = split self.training = split == 'train' self.logger.debug("Starting loading dataset") self._raw_file_names = [] self._raw_paths = [] self._raw_file_dataset = [] if raw_dir is not None: self._raw_dir = raw_dir for raw_dir in self._raw_dir: raw_dir = os.path.expanduser(os.path.normpath(raw_dir)) dataset = "waymo" file_list = os.listdir(raw_dir) self._raw_file_names.extend(file_list) self._raw_paths.extend([os.path.join(raw_dir, f) for f in file_list]) self._raw_file_dataset.extend([dataset for _ in range(len(file_list))]) if self.root is not None: split_datainfo = os.path.join(root, "split_datainfo.pkl") with open(split_datainfo, 'rb+') as f: split_datainfo = pickle.load(f) if split == "test": split = "val" self._processed_file_names = split_datainfo[split] self.dim = dim self.num_historical_steps = num_historical_steps self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names) self.logger.debug("The number of {} dataset is ".format(split) + str(self._num_samples)) self.token_processor = TokenProcessor(2048) super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None) @property def raw_dir(self) -> str: return self._raw_dir @property def raw_paths(self) -> List[str]: return self._raw_paths @property def raw_file_names(self) -> Union[str, List[str], Tuple]: return self._raw_file_names @property def processed_file_names(self) -> Union[str, List[str], Tuple]: return self._processed_file_names def len(self) -> int: return self._num_samples def generate_ref_token(self): pass def get(self, idx: int): with open(self.raw_paths[idx], 'rb') as handle: data = pickle.load(handle) data = self.token_processor.preprocess(data) return data ================================================ FILE: smart/layers/__init__.py 
================================================
from smart.layers.attention_layer import AttentionLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from smart.layers.mlp_layer import MLPLayer


================================================
FILE: smart/layers/attention_layer.py
================================================
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax

from smart.utils import weight_init


class AttentionLayer(MessagePassing):
    """Graph multi-head attention layer (pre-norm transformer block) over
    edge_index, with optional relative positional embeddings `r` and a gated
    update of destination features.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 bipartite: bool,
                 has_pos_emb: bool,
                 **kwargs) -> None:
        super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_pos_emb = has_pos_emb
        self.scale = head_dim ** -0.5  # 1/sqrt(d) attention scaling
        self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
        self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
        if has_pos_emb:
            # Projections for the relative positional embedding r (per edge).
            self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
            self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
        self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.ff_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        if bipartite:
            # Distinct norms for source/destination node sets.
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
        else:
            # Shared norm when source and destination are the same node set.
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = self.attn_prenorm_x_src
        if has_pos_emb:
            self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
        self.attn_postnorm = nn.LayerNorm(hidden_dim)
        self.ff_prenorm = nn.LayerNorm(hidden_dim)
        self.ff_postnorm = nn.LayerNorm(hidden_dim)
        self.apply(weight_init)

    def forward(self,
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                r: Optional[torch.Tensor],
                edge_index: torch.Tensor) -> torch.Tensor:
        """x may be a single tensor (self-attention) or an (x_src, x_dst)
        tuple (bipartite attention); residuals are taken on the destination.
        """
        if isinstance(x, torch.Tensor):
            x_src = x_dst = self.attn_prenorm_x_src(x)
        else:
            x_src, x_dst = x
            x_src = self.attn_prenorm_x_src(x_src)
            x_dst = self.attn_prenorm_x_dst(x_dst)
            x = x[1]  # residual stream is the destination features
        if self.has_pos_emb and r is not None:
            r = self.attn_prenorm_r(r)
        x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
        x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
        return x

    def message(self,
                q_i: torch.Tensor,
                k_j: torch.Tensor,
                v_j: torch.Tensor,
                r: Optional[torch.Tensor],
                index: torch.Tensor,
                ptr: Optional[torch.Tensor]) -> torch.Tensor:
        # Positional embedding is added to keys and values per edge.
        if self.has_pos_emb and r is not None:
            k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
            v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
        sim = (q_i * k_j).sum(dim=-1) * self.scale
        # Softmax over incoming edges per destination node.
        attn = softmax(sim, index, ptr)
        # Kept for inspection/visualization; detached from autograd.
        self.attention_weight = attn.sum(-1).detach()
        attn = self.attn_drop(attn)
        return v_j * attn.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x_dst: torch.Tensor) -> torch.Tensor:
        # Gated blend of aggregated messages with destination features.
        inputs = inputs.view(-1, self.num_heads * self.head_dim)
        g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
        return inputs + g * (self.to_s(x_dst) - inputs)

    def _attn_block(self,
                    x_src: torch.Tensor,
                    x_dst: torch.Tensor,
                    r: Optional[torch.Tensor],
                    edge_index: torch.Tensor) -> torch.Tensor:
        q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
        k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
        v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
        return self.to_out(agg)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff_mlp(x)


================================================
FILE: smart/layers/fourier_embedding.py
================================================
import math
from typing import List, Optional

import torch
import torch.nn as nn

from smart.utils import weight_init


class FourierEmbedding(nn.Module):
    """Embeds continuous features with learnable sinusoidal frequencies,
    one MLP per input dimension, optionally summed with categorical embeddings.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_freq_bands: int) -> None:
        super(FourierEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # One learnable frequency row per input dimension; None when the module
        # is used with categorical embeddings only (input_dim == 0).
        self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
        self.mlps = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
            )
                for _ in range(input_dim)])
        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
            # Warning: if your data are noisy, don't use learnable sinusoidal embedding
            # Concatenate cos/sin features with the raw input per dimension.
            x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
            continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
            for i in range(self.input_dim):
                continuous_embs[i] = self.mlps[i](x[:, i])
            x = torch.stack(continuous_embs).sum(dim=0)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return self.to_out(x)


class MLPEmbedding(nn.Module):
    """Plain MLP embedding of continuous features, optionally summed with
    categorical embeddings; same interface as FourierEmbedding.
    """

    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super(MLPEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim))
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = self.mlp(continuous_inputs)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return x


================================================
FILE: smart/layers/mlp_layer.py
================================================
import torch
import torch.nn as nn

from smart.utils import weight_init


class MLPLayer(nn.Module):
    """Two-layer MLP head: Linear -> LayerNorm -> ReLU -> Linear."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super(MLPLayer, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )
        self.apply(weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


================================================
FILE: smart/metrics/__init__.py
================================================
from smart.metrics.average_meter import AverageMeter
from smart.metrics.min_ade import minADE
from smart.metrics.min_fde import minFDE
from smart.metrics.next_token_cls import TokenCls


================================================
FILE: smart/metrics/average_meter.py
================================================
import torch
from torchmetrics import Metric


class AverageMeter(Metric):
    """Distributed-safe running mean: sum of all values / number of elements."""

    def __init__(self, **kwargs) -> None:
        super(AverageMeter, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, val: torch.Tensor) -> None:
        self.sum += val.sum()
        self.count += val.numel()

    def compute(self) -> torch.Tensor:
return self.sum / self.count ================================================ FILE: smart/metrics/min_ade.py ================================================ from typing import Optional import torch from torchmetrics import Metric from smart.metrics.utils import topk from smart.metrics.utils import valid_filter class minMultiADE(Metric): def __init__(self, max_guesses: int = 6, **kwargs) -> None: super(minMultiADE, self).__init__(**kwargs) self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') self.max_guesses = max_guesses def update(self, pred: torch.Tensor, target: torch.Tensor, prob: Optional[torch.Tensor] = None, valid_mask: Optional[torch.Tensor] = None, keep_invalid_final_step: bool = True, min_criterion: str = 'FDE') -> None: pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step) pred_topk, _ = topk(self.max_guesses, pred, prob) if min_criterion == 'FDE': inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1) inds_best = torch.norm( pred_topk[torch.arange(pred.size(0)), :, inds_last] - target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1) self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) * valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum() elif min_criterion == 'ADE': self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) * valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum() else: raise ValueError('{} is not a valid criterion'.format(min_criterion)) self.count += pred.size(0) def compute(self) -> torch.Tensor: return self.sum / self.count class minADE(Metric): def __init__(self, max_guesses: int = 6, **kwargs) -> None: super(minADE, self).__init__(**kwargs) self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 
        # 'sum' accumulates per-agent displacement error, 'count' the number of scored agents.
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        # Fixed evaluation horizon in rollout steps. NOTE(review): presumably 70 future
        # steps of the Waymo sim-agents setting — confirm against the dataset config.
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'ADE') -> None:
        """Accumulate average displacement error over the first `eval_timestep` steps.

        NOTE(review): the commented-out block below is the original multi-modal
        min-ADE/min-FDE logic; the active code scores a single rollout per agent
        and normalizes by pred.shape[1] (the full prediction horizon), NOT by the
        number of valid steps — confirm this normalization is intentional.
        """
        # pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        # pred_topk, _ = topk(self.max_guesses, pred, prob)
        # if min_criterion == 'FDE':
        #     inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
        #     inds_best = torch.norm(
        #         pred[torch.arange(pred.size(0)), :, inds_last] -
        #         target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
        #     self.sum += ((torch.norm(pred[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
        #                   valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
        # elif min_criterion == 'ADE':
        #     self.sum += ((torch.norm(pred - target.unsqueeze(1), p=2, dim=-1) *
        #                   valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
        # else:
        #     raise ValueError('{} is not a valid criterion'.format(min_criterion))
        # Clamp the horizon to what the rollout actually produced.
        eval_timestep = min(self.eval_timestep, pred.shape[1])
        self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) *
                      valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
        # An agent counts toward the denominator if it has any valid step in the window.
        self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()

    def compute(self) -> torch.Tensor:
        # Mean displacement error over all agents seen so far.
        return self.sum / self.count


================================================
FILE: smart/metrics/min_fde.py
================================================
from typing import Optional

import torch
from torchmetrics import Metric

from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter


class minMultiFDE(Metric):
    # Multi-modal minimum final displacement error: best-of-k endpoint error,
    # where the endpoint is each agent's last *valid* timestep.

    def __init__(self, max_guesses: int = 6, **kwargs) -> None:
        super(minMultiFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        """Accumulate the min-over-modes endpoint error for each kept agent."""
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        # Index of the last valid timestep per agent (argmax of weighted positions).
        inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
        self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                               target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
                               p=2, dim=-1).min(dim=-1)[0].sum()
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class minFDE(Metric):
    # Single-rollout final displacement error at a fixed evaluation timestep.

    def __init__(self, max_guesses: int = 6, **kwargs) -> None:
        super(minFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        # Evaluation horizon in rollout steps (same convention as minADE).
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        """Accumulate the endpoint error at step `eval_timestep`.

        NOTE(review): both `min(...) - 1` and the further `eval_timestep-1`
        indexing are applied, i.e. the scored step is horizon-2 — confirm this
        double offset is intended rather than an off-by-one.
        """
        eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
        self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
                      valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
        # Only agents valid at the scored step count toward the denominator.
        self.count += valid_mask[:, eval_timestep-1].sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


================================================
FILE: smart/metrics/next_token_cls.py
================================================
from typing import Optional

import torch
from torchmetrics import Metric

from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter


class TokenCls(Metric):
    # Top-k next-token classification accuracy over valid positions.

    def __init__(self, max_guesses: int = 6, **kwargs) -> None:
        super(TokenCls, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:
        """Count a hit when the target id appears among the first `max_guesses` predictions."""
        target = target[..., None]
        acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask
        self.sum += acc.sum()
        self.count += valid_mask.sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


================================================
FILE: smart/metrics/utils.py
================================================
from typing import Optional, Tuple

import torch
from torch_scatter import gather_csr
from torch_scatter import segment_csr


def topk(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the `max_guesses` highest-probability modes (and renormalized probs).

    With `joint=True` modes are ranked by the probability averaged across agents
    (optionally per-graph via the CSR `ptr`); otherwise per-row independently.
    If `prob` is None the first `max_guesses` modes are kept with uniform probability.
    """
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        # Nothing to drop: just (re)normalize or synthesize uniform probabilities.
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred, prob
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    # Rank modes by the probability averaged over all rows.
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    # Rank per graph segment, then broadcast back to its rows.
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk =
prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True) else: pred_topk = pred[:, :max_guesses] prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses return pred_topk, prob_topk def topkind( max_guesses: int, pred: torch.Tensor, prob: Optional[torch.Tensor] = None, ptr: Optional[torch.Tensor] = None, joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: max_guesses = min(max_guesses, pred.size(1)) if max_guesses == pred.size(1): if prob is not None: prob = prob / prob.sum(dim=-1, keepdim=True) else: prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses return pred, prob, None else: if prob is not None: if joint: if ptr is None: inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True), k=max_guesses, dim=-1, largest=True, sorted=True)[1] inds_topk = inds_topk.repeat(pred.size(0), 1) else: inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr, reduce='mean'), k=max_guesses, dim=-1, largest=True, sorted=True)[1] inds_topk = gather_csr(src=inds_topk, indptr=ptr) else: inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1] pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk] prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True) else: pred_topk = pred[:, :max_guesses] prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses return pred_topk, prob_topk, inds_topk def valid_filter( pred: torch.Tensor, target: torch.Tensor, prob: Optional[torch.Tensor] = None, valid_mask: Optional[torch.Tensor] = None, ptr: Optional[torch.Tensor] = None, keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor]: if valid_mask is None: valid_mask = 
target.new_ones(target.size()[:-1], dtype=torch.bool)
    # Keep agents that have at least one valid step (or a valid final step).
    if keep_invalid_final_step:
        filter_mask = valid_mask.any(dim=-1)
    else:
        filter_mask = valid_mask[:, -1]
    pred = pred[filter_mask]
    target = target[filter_mask]
    if prob is not None:
        prob = prob[filter_mask]
    valid_mask = valid_mask[filter_mask]
    if ptr is not None:
        # Rebuild the CSR pointer to reflect the rows surviving the filter.
        num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
        ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
        torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
    else:
        ptr = target.new_tensor([0, target.size(0)])
    return pred, target, prob, valid_mask, ptr


def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
    """Greedy endpoint NMS where scores are derived from endpoint density.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
    pred_goals = pred_trajs[:, :, -1, :]
    # Score each mode by how many other endpoints fall within dist_thresh of it.
    dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
    nearby_neighbor = dist < dist_thresh
    pred_scores = nearby_neighbor.sum(dim=-1) / num_modes

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)

    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
    point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    # Greedy selection: pick the best-scoring remaining mode, then suppress
    # everything its endpoint covers; the -1 marker prevents re-selection.
    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    # Map positions in the sorted order back to original mode indices.
    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs


def batch_nms(pred_trajs, pred_scores, dist_thresh, num_ret_modes=6, mode='static', speed=None):
    """Greedy endpoint NMS driven by externally supplied scores.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)

    if mode == "speed":
        # Separate longitudinal/lateral suppression thresholds (x resp. y axis).
        # NOTE(review): `speed` parameter is accepted but unused — confirm.
        scale = torch.ones(batch_size).to(sorted_pred_goals.device)
        lon_dist_thresh = 4 * scale
        lat_dist_thresh = 0.5 * scale
        lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
        lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
        point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
    else:
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    # Same greedy pick-and-suppress loop as new_batch_nms.
    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs


def batch_nms_token(pred_trajs, pred_scores, dist_thresh, num_ret_modes=6, mode='static', speed=None):
    """Greedy NMS over single-point (token endpoint) predictions.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)

    if mode == "nearby":
        # Adaptive radius: suppress each mode's 5 nearest neighbours.
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        values, indices = torch.topk(dist, 5, dim=-1, largest=False)
        thresh_hold = values[..., -1]
        point_cover_mask = dist < thresh_hold[..., None]
    else:
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
    ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_goals, ret_scores, ret_idxs


================================================
FILE: smart/model/__init__.py
================================================
from smart.model.smart import SMART


================================================
FILE: smart/model/smart.py
================================================
import contextlib
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from smart.metrics import minADE
from smart.metrics import minFDE
from smart.metrics import TokenCls
from smart.modules import SMARTDecoder
from torch.optim.lr_scheduler import LambdaLR
import math
import numpy as np
import pickle
from collections import defaultdict
import os
from waymo_open_dataset.protos import sim_agents_submission_pb2


def cal_polygon_contour(x, y, theta, width, length):
    # Corners of an oriented bounding box centred at (x, y) with heading theta,
    # returned as [left_front, right_front, right_back, left_back].
    left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
    left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
    left_front = (left_front_x, left_front_y)

    right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
    right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
    right_front = (right_front_x, right_front_y)

    right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
    right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
    right_back = (right_back_x, right_back_y)

    left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
    left_back_y = y - 0.5 * length *
math.sin(theta) + 0.5 * width * math.cos(theta)
    left_back = (left_back_x, left_back_y)
    polygon_contour = [left_front, right_front, right_back, left_back]
    return polygon_contour


def joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene:
    """Pack per-agent (x, y, z, heading) rollouts into a sim-agents JointScene proto."""
    states = states.numpy()
    simulated_trajectories = []
    for i_object in range(len(object_ids)):
        simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory(
            center_x=states[i_object, :, 0], center_y=states[i_object, :, 1],
            center_z=states[i_object, :, 2], heading=states[i_object, :, 3],
            object_id=object_ids[i_object].item()
        ))
    return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories)


class SMART(pl.LightningModule):
    """LightningModule wrapping the SMART token-based motion decoder.

    Training is next-token classification over clustered trajectory tokens;
    validation optionally rolls the model out autoregressively and scores
    ADE/FDE against ground truth.
    """

    def __init__(self, model_config) -> None:
        super(SMART, self).__init__()
        self.save_hyperparameters()
        self.model_config = model_config
        self.warmup_steps = model_config.warmup_steps
        self.lr = model_config.lr
        self.total_steps = model_config.total_steps
        self.dataset = model_config.dataset
        self.input_dim = model_config.input_dim
        self.hidden_dim = model_config.hidden_dim
        self.output_dim = model_config.output_dim
        self.output_head = model_config.output_head
        self.num_historical_steps = model_config.num_historical_steps
        self.num_future_steps = model_config.decoder.num_future_steps
        self.num_freq_bands = model_config.num_freq_bands
        self.vis_map = False
        # When True, map tokenization samples among the 8 nearest tokens instead
        # of always taking the argmin (data augmentation; see match_token_map).
        self.noise = True
        # Token vocabularies ship with the package, relative to the smart/ dir.
        module_dir = os.path.dirname(os.path.dirname(__file__))
        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        self.init_map_token()
        self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl')
        token_data = self.get_trajectory_token()
        self.encoder = SMARTDecoder(
            dataset=model_config.dataset,
            input_dim=model_config.input_dim,
            hidden_dim=model_config.hidden_dim,
            num_historical_steps=model_config.num_historical_steps,
            num_freq_bands=model_config.num_freq_bands,
            num_heads=model_config.num_heads,
            head_dim=model_config.head_dim,
            dropout=model_config.dropout,
            num_map_layers=model_config.decoder.num_map_layers,
            num_agent_layers=model_config.decoder.num_agent_layers,
            pl2pl_radius=model_config.decoder.pl2pl_radius,
            pl2a_radius=model_config.decoder.pl2a_radius,
            a2a_radius=model_config.decoder.a2a_radius,
            time_span=model_config.decoder.time_span,
            map_token={'traj_src': self.map_token['traj_src']},
            token_data=token_data,
            token_size=model_config.decoder.token_size
        )
        self.minADE = minADE(max_guesses=1)
        self.minFDE = minFDE(max_guesses=1)
        self.TokenCls = TokenCls(max_guesses=1)
        self.test_predictions = dict()
        self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        # If True, validation additionally runs the autoregressive rollout.
        self.inference_token = False
        self.rollout_num = 1

    def get_trajectory_token(self):
        """Load the clustered agent-trajectory token vocabulary from disk."""
        token_data = pickle.load(open(self.token_path, 'rb'))
        self.trajectory_token = token_data['token']
        self.trajectory_token_traj = token_data['traj']
        self.trajectory_token_all = token_data['token_all']
        return token_data

    def init_map_token(self):
        """Load map-token templates and precompute sample points / end headings."""
        self.argmin_sample_len = 3
        map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
        self.map_token = {'traj_src': map_token_traj['traj_src'], }
        # Heading of each template at its end point, from its last segment.
        traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],
                                    self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])
        # Evenly spaced sample points used for nearest-token matching.
        indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()
        self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)
        self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
        self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)

    def forward(self, data: HeteroData):
        res = self.encoder(data)
        return res

    def inference(self, data: HeteroData):
        # Autoregressive rollout path of the decoder.
        res = self.encoder.inference(data)
        return res

    def maybe_autocast(self, dtype=torch.float16):
        # CUDA autocast context on GPU; no-op context on CPU.
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    def training_step(self, data, batch_idx):
        """One step of next-token cross-entropy training."""
        data = self.match_token_map(data)
        data = self.sample_pt_pred(data)
        if isinstance(data, Batch):
            # Shift per-graph AV indices to batch-global node indices.
            data['agent']['av_index'] += data['agent']['ptr'][:-1]
        pred = self(data)
        next_token_prob = pred['next_token_prob']
        next_token_idx_gt = pred['next_token_idx_gt']
        next_token_eval_mask = pred['next_token_eval_mask']
        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
        loss = cls_loss
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        return loss

    def validation_step(self, data, batch_idx):
        """Teacher-forced token accuracy/loss, plus optional rollout ADE/FDE."""
        data = self.match_token_map(data)
        data = self.sample_pt_pred(data)
        if isinstance(data, Batch):
            data['agent']['av_index'] += data['agent']['ptr'][:-1]
        pred = self(data)
        next_token_idx = pred['next_token_idx']
        next_token_idx_gt = pred['next_token_idx_gt']
        next_token_eval_mask = pred['next_token_eval_mask']
        next_token_prob = pred['next_token_prob']
        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
        loss = cls_loss
        self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
                             valid_mask=next_token_eval_mask[next_token_eval_mask])
        self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)

        # Agents valid at the last history step are the ones scored at rollout.
        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]  # * (data['agent']['category'] == 3)
        if self.inference_token:
            pred = self.inference(data)
            pos_a = pred['pos_a']
            gt = pred['gt']
            valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:]
            pred_traj = pred['pred_traj']
            # next_token_idx = pred['next_token_idx'][..., None]
            # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:]
            # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:]
            # next_token_eval_mask[:, 1:] = False
            # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
            #                      valid_mask=next_token_eval_mask[next_token_eval_mask])
            # self.log('val_inference_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
            eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]

            self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
            self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
            # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute())
            self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)
            self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)

    def on_validation_start(self):
        # Per-epoch accumulators for rollout bookkeeping / submission export.
        self.gt = []
        self.pred = []
        self.scenario_rollouts = []
        self.batch_metric = defaultdict(list)

    def configure_optimizers(self):
        """AdamW with linear warmup followed by cosine decay to zero."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)

        def lr_lambda(current_step):
            if current_step + 1 < self.warmup_steps:
                return float(current_step + 1) / float(max(1, self.warmup_steps))
            return max(
                0.0, 0.5 * (1.0 + math.cos(
                    math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))
            )

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
        return [optimizer], [lr_scheduler]

    def load_params_from_file(self, filename, logger, to_cpu=False):
        """Load a checkpoint, skipping keys that are missing or shape-mismatched.

        Returns:
            (it, epoch) recovered from the checkpoint (0.0 / -1 if absent).
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError
        logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
        loc_type = torch.device('cpu') if to_cpu else None
        checkpoint = torch.load(filename, map_location=loc_type)
        model_state_disk = checkpoint['state_dict']
        version = checkpoint.get("version", None)
        if version is not None:
            logger.info('==> Checkpoint trained from version: %s' % version)
        logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')
        model_state = self.state_dict()
        model_state_disk_filter = {}
        for key, val in model_state_disk.items():
            if key in model_state and model_state_disk[key].shape == model_state[key].shape:
                model_state_disk_filter[key] = val
            else:
                if key not in model_state:
                    print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')
                else:
                    print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}')
        model_state_disk = model_state_disk_filter
        missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)
        logger.info(f'Missing keys: {missing_keys}')
        logger.info(f'The number of missing keys: {len(missing_keys)}')
        logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')
        logger.info('==> Done (total keys %d)' % (len(model_state)))
        epoch = checkpoint.get('epoch', -1)
        it = checkpoint.get('it', 0.0)
        return it, epoch

    def match_token_map(self, data):
        """Quantize map polyline segments to their nearest map-token ids.

        Builds `data['pt_token']` fields (token ids, positions, orientations,
        per-polygon side masks) and the pt_token->map_polygon edge index.
        """
        traj_pos = data['map_save']['traj_pos'].to(torch.float)
        traj_theta = data['map_save']['traj_theta'].to(torch.float)
        pl_idx_list = data['map_save']['pl_idx_list']
        token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)
        token_src = self.map_token['traj_src'].to(traj_pos.device)
        max_traj_len = self.map_token['traj_src'].shape[1]
        pl_num = traj_pos.shape[0]

        pt_token_pos = traj_pos[:, 0, :].clone()
        pt_token_orientation = traj_theta.clone()
        # Rotate each polyline into its own start-heading frame before matching.
        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[..., 0, 0] = cos
        rot_mat[..., 0, 1] = -sin
        rot_mat[..., 1, 0] = sin
        rot_mat[..., 1, 1] = cos
        traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
        distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))
        pt_token_id = torch.argmin(distance, dim=1)
        if self.noise:
            # Augmentation: sample uniformly among the 8 closest tokens.
            topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
            sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
            pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)

        # Inverse rotation: map the chosen token template back to world frame.
        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[..., 0, 0] = cos
        rot_mat[..., 0, 1] = sin
        rot_mat[..., 1, 0] = -sin
        rot_mat[..., 1, 1] = cos
        token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
                                    rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
        # NOTE(review): `token_src_world_select` is computed but never used below
        # (likely kept for visualization/debugging). The hard-coded 1024/11 must
        # match the map-token vocabulary size and template length — confirm.
        token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)

        pl_idx_full = pl_idx_list.clone()
        token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
        # Count tokens per polygon split by side (0=left, 1=right, 2=center).
        count_nums = []
        for pl in pl_idx_full.unique():
            pt = token2pl[0, token2pl[1, :] == pl]
            left_side = (data['pt_token']['side'][pt] == 0).sum()
            right_side = (data['pt_token']['side'][pt] == 1).sum()
            center_side = (data['pt_token']['side'][pt] == 2).sum()
            count_nums.append(torch.Tensor([left_side, right_side, center_side]))
        count_nums = torch.stack(count_nums, dim=0)
        num_polyline = int(count_nums.max().item())
        # Boolean (num_polygons, 3 sides, max_tokens) mask marking real tokens.
        traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)
        idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
        idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)  #
        counts_num_expanded = count_nums.unsqueeze(-1)
        mask_update = idx_matrix < counts_num_expanded
        traj_mask[mask_update] = True

        data['pt_token']['traj_mask'] = traj_mask
        data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                            device=traj_pos.device, dtype=torch.float)], dim=-1)
        data['pt_token']['orientation'] = pt_token_orientation
        data['pt_token']['height'] = data['pt_token']['position'][:, -1]
        data[('pt_token', 'to', 'map_polygon')] = {}
        data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
        data['pt_token']['token_idx'] = pt_token_id
        return data

    def sample_pt_pred(self, data):
        """Randomly mask ~1/3 of map tokens per side for the masked-token task.

        Produces flattened masks over real tokens: `pt_valid_mask` (visible
        tokens), `pt_pred_mask` (positions asked to predict their successor)
        and `pt_target_mask` (the corresponding target positions).
        """
        traj_mask = data['pt_token']['traj_mask']
        # Candidate maskable positions exclude index 0 of every side sequence.
        raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
        masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)]
        masked_pt_index = torch.sort(masked_pt_index, -1)[0]
        pt_valid_mask = traj_mask.clone()
        pt_valid_mask.scatter_(2, masked_pt_index, False)
        pt_pred_mask = traj_mask.clone()
        pt_pred_mask.scatter_(2, masked_pt_index, False)
        # Keep only positions immediately preceding a masked index...
        tmp_mask = pt_pred_mask.clone()
        tmp_mask[:, :, :] = True
        tmp_mask.scatter_(2, masked_pt_index-1, False)
        pt_pred_mask.masked_fill_(tmp_mask, False)
        # ...whose successor is a real token.
        pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
        pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)
        data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
        data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
        data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]
        return data


================================================
FILE: smart/modules/__init__.py
================================================
from smart.modules.smart_decoder import SMARTDecoder
from smart.modules.map_decoder import SMARTMapDecoder
from smart.modules.agent_decoder import SMARTAgentDecoder


================================================
FILE: smart/modules/agent_decoder.py
================================================
import pickle
from typing import Dict, Mapping, Optional

import torch
import torch.nn as nn

from smart.layers import MLPLayer
from smart.layers.attention_layer import AttentionLayer
from
smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding from torch_cluster import radius, radius_graph from torch_geometric.data import Batch, HeteroData from torch_geometric.utils import dense_to_sparse, subgraph from smart.utils import angle_between_2d_vectors, weight_init, wrap_angle import math def cal_polygon_contour(x, y, theta, width, length): left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) left_front = (left_front_x, left_front_y) right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) right_front = (right_front_x, right_front_y) right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta) right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta) right_back = (right_back_x, right_back_y) left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta) left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta) left_back = (left_back_x, left_back_y) polygon_contour = [left_front, right_front, right_back, left_back] return polygon_contour class SMARTAgentDecoder(nn.Module): def __init__(self, dataset: str, input_dim: int, hidden_dim: int, num_historical_steps: int, time_span: Optional[int], pl2a_radius: float, a2a_radius: float, num_freq_bands: int, num_layers: int, num_heads: int, head_dim: int, dropout: float, token_data: Dict, token_size=512) -> None: super(SMARTAgentDecoder, self).__init__() self.dataset = dataset self.input_dim = input_dim self.hidden_dim = hidden_dim self.num_historical_steps = num_historical_steps self.time_span = time_span if time_span is not None else num_historical_steps self.pl2a_radius = pl2a_radius self.a2a_radius = a2a_radius self.num_freq_bands = num_freq_bands self.num_layers = num_layers self.num_heads = 
num_heads self.head_dim = head_dim self.dropout = dropout input_dim_x_a = 2 input_dim_r_t = 4 input_dim_r_pt2a = 3 input_dim_r_a2a = 3 input_dim_token = 8 self.type_a_emb = nn.Embedding(4, hidden_dim) self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim) self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands) self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim) self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim) self.t_attn_layers = nn.ModuleList( [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, bipartite=False, has_pos_emb=True) for _ in range(num_layers)] ) self.pt2a_attn_layers = nn.ModuleList( [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, bipartite=True, has_pos_emb=True) for _ in range(num_layers)] ) self.a2a_attn_layers = nn.ModuleList( [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout, bipartite=False, has_pos_emb=True) for _ in range(num_layers)] ) self.token_size = token_size self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=self.token_size) self.trajectory_token = token_data['token'] self.trajectory_token_traj = token_data['traj'] self.trajectory_token_all = token_data['token_all'] self.apply(weight_init) self.shift = 5 self.beam_size = 5 self.hist_mask 
    def transform_rel(self, token_traj, prev_pos, prev_heading=None):
        """Rotate local token trajectories by the previous heading and translate them
        to the previous position, producing world-frame candidate trajectories.

        Args:
            token_traj: [num_agent, num_step, traj_num, 2] local-frame token contours.
            prev_pos: [num_agent, num_step, T, 2] past positions; the last slot is the anchor.
            prev_heading: optional [num_agent, num_step] heading; derived from the last
                two positions when omitted.
        """
        if prev_heading is None:
            # Infer heading from the displacement between the last two recorded positions.
            diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
            prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])

        num_agent, num_step, traj_num, traj_dim = token_traj.shape
        cos, sin = prev_heading.cos(), prev_heading.sin()
        # Standard 2D rotation matrix per (agent, step).
        rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
        rot_mat[:, :, 0, 0] = cos
        rot_mat[:, :, 0, 1] = -sin
        rot_mat[:, :, 1, 0] = sin
        rot_mat[:, :, 1, 1] = cos
        agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2),
                                   rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
        # Translate rotated contours to the most recent position.
        agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
        return agent_pred_rel

    def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False):
        """Embed per-agent motion tokens and continuous kinematic features into fused features.

        Returns (feat_a, agent_token_traj) during training; in inference mode additionally
        returns the full rollout token trajectories, the raw token embeddings and the
        categorical embeddings so the autoregressive loop can reuse them.
        """
        num_agent, num_step, traj_dim = pos_a.shape
        # Per-step displacement; first step has no predecessor so it is zero-padded.
        motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
                                     pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
        agent_type = data['agent']['type']
        veh_mask = (agent_type == 0)
        cyc_mask = (agent_type == 2)
        ped_mask = (agent_type == 1)
        # Token vocabularies are numpy arrays loaded from the token pickle; embed each
        # class's vocabulary with its own MLP and cache the result on self for inference.
        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1))
        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
        if inference:
            # Full per-token rollout (shift intermediate frames + final contour) used to
            # reconstruct continuous trajectories from sampled token indices.
            agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2),
                                               device=pos_a.device)
            trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(
                torch.float)
            trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(
                torch.float)
            trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(
                torch.float)
            agent_token_traj_all[veh_mask] = torch.cat(
                [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
            agent_token_traj_all[ped_mask] = torch.cat(
                [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
            agent_token_traj_all[cyc_mask] = torch.cat(
                [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)

        # Gather the embedding of the observed token index at every step, per class.
        agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
        agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
        agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
        agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]

        # Candidate token contours ([token_size, 4, 2]) broadcast to every agent/step.
        agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device)
        agent_token_traj[veh_mask] = trajectory_token_veh
        agent_token_traj[ped_mask] = trajectory_token_ped
        agent_token_traj[cyc_mask] = trajectory_token_cyc

        vel = data['agent']['token_velocity']  # NOTE(review): read but unused here — presumably legacy.

        categorical_embs = [
            self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step, dim=0),
            self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave(
                repeats=num_step, dim=0)
        ]
        # Continuous features: speed magnitude and motion direction relative to heading.
        feature_a = torch.stack(
            [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
             ], dim=-1)
        x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
                           categorical_embs=categorical_embs)
        x_a = x_a.view(-1, num_step, self.hidden_dim)
        # Fuse token embedding with kinematic embedding (2*hidden -> hidden).
        feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
        feat_a = self.fusion_emb(feat_a)
        if inference:
            return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs
        else:
            return feat_a, agent_token_traj

    def agent_predict_next(self, data, agent_category, feat_a):
        """Predict next-token logits per agent/step using class-specific heads.

        NOTE(review): `token_predict_cyc_head` / `token_predict_walker_head` are not
        created in the visible part of __init__ — confirm they exist before calling.
        """
        num_agent, num_step, traj_dim = data['agent']['token_pos'].shape
        agent_type = data['agent']['type']
        veh_mask = (agent_type == 0)  # * agent_category==3
        cyc_mask = (agent_type == 2)  # * agent_category==3
        ped_mask = (agent_type == 1)  # * agent_category==3
        token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device)
        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
        token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])
        token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])
        return token_res

    def agent_predict_next_inf(self, data, agent_category, feat_a):
        """Single-step variant of agent_predict_next used during inference
        (feat_a is [num_agent, hidden] for the current step only)."""
        num_agent, traj_dim = feat_a.shape
        agent_type = data['agent']['type']
        veh_mask = (agent_type == 0)  # * agent_category==3
        cyc_mask = (agent_type == 2)  # * agent_category==3
        ped_mask = (agent_type == 1)  # * agent_category==3
        token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device)
        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
        token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])
        token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])
        return token_res

    def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None):
        """Build per-agent temporal attention edges within the time-span window and
        embed their relative geometry (distance, bearing, heading delta, step gap)."""
        pos_t = pos_a.reshape(-1, self.input_dim)
        head_t = head_a.reshape(-1)
        head_vector_t = head_vector_a.reshape(-1, 2)
        hist_mask = mask.clone()
        if self.hist_mask and self.training:
            # Training-time augmentation: randomly drop 10 steps per agent from history.
            hist_mask[
                torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
        elif inference_mask is not None:
            mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
        else:
            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
        edge_index_t = dense_to_sparse(mask_t)[0]
        # Causal: only edges from earlier to later steps.
        edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]]
        # Limit attention span (time_span is in raw steps, indices are in token steps).
        edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift]
        rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
        rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
        r_t = torch.stack(
            [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
             rel_head_t,
             edge_index_t[0] - edge_index_t[1]], dim=-1)
        r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
        return edge_index_t, r_t

    def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s):
        """Build agent-to-agent edges within a2a_radius per (step, graph) and embed
        their relative geometry."""
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
        # batch_s separates both graphs in a Batch and time steps, so neighbors are
        # only found within the same step of the same scenario.
        edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
                                      max_num_neighbors=300)
        edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]
        rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
        rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
        r_a2a = torch.stack(
            [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
             rel_head_a2a], dim=-1)
        r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
        return edge_index_a2a, r_a2a

    def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a,
                             mask, batch_s, batch_pl):
        """Build map-token-to-agent edges within pl2a_radius and embed their relative
        geometry. Map tokens are static, so their positions are tiled over time."""
        mask_pl2a = mask.clone()
        mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
        pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pl = data['pt_token']['orientation'].contiguous()
        # Tile static map tokens across all time steps.
        pos_pl = pos_pl.repeat(num_step, 1)
        orient_pl = orient_pl.repeat(num_step)
        edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
                                 batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
        # Keep only edges whose destination agent-step is valid.
        edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]
        rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
        rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
        r_pl2a = torch.stack(
            [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
             rel_orient_pl2a], dim=-1)
        r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
        return edge_index_pl2a, r_pl2a

    def forward(self, data: HeteroData, map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Teacher-forced pass: attend over time, map and neighbors, then predict
        next-token logits for every agent/step (GT is the token sequence rolled by -1)."""
        pos_a = data['agent']['token_pos']
        head_a = data['agent']['token_heading']
        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
        num_agent, num_step, traj_dim = pos_a.shape
        agent_category = data['agent']['category']
        agent_token_index = data['agent']['token_idx']
        feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index,
                                                              pos_a, head_vector_a)

        agent_valid_mask = data['agent']['agent_valid_mask'].clone()
        # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
        # agent_valid_mask[~eval_mask] = False
        mask = agent_valid_mask
        edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask)

        # Per-(graph, step) batch vectors so radius queries stay within one step.
        if isinstance(data, Batch):
            batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
                                 for t in range(num_step)], dim=0)
            batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
                                  for t in range(num_step)], dim=0)
        else:
            batch_s = torch.arange(num_step,
                                   device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
            batch_pl = torch.arange(num_step,
                                    device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])

        mask_s = mask.transpose(0, 1).reshape(-1)
        edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s)
        # Only category-3 (to-be-predicted) agents receive map context; this also
        # mutates `mask` (== agent_valid_mask) in place for the eval mask below.
        mask[agent_category != 3] = False
        edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
                                                            head_vector_a, mask, batch_s, batch_pl)

        for i in range(self.num_layers):
            # Temporal attention operates on [agent*step, hidden] in agent-major order...
            feat_a = feat_a.reshape(-1, self.hidden_dim)
            feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
            # ...while map and social attention use step-major order, hence the transposes.
            feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
            feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
                repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
            feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
            feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)

        # NOTE(review): third dim here is actually token_size, not hidden_dim.
        num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape
        next_token_prob = self.token_predict_head(feat_a)
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
        # Ground truth for step s is the token observed at step s+1.
        next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
        # A step is evaluable only if it and both neighbors are valid; the last step
        # has no successor (roll wraps around) so it is always excluded.
        next_token_eval_mask = mask.clone()
        next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
        next_token_eval_mask[:, -1] = False
        return {'x_a': feat_a,
                'next_token_idx': next_token_idx,
                'next_token_prob': next_token_prob,
                'next_token_idx_gt': next_token_index_gt,
                'next_token_eval_mask': next_token_eval_mask,
                }

    def inference(self, data: HeteroData, map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Autoregressive rollout: starting after the historical token steps, repeatedly
        sample the next motion token (from the top beam_size candidates), decode it into
        positions/headings, and feed the updated state back into the next step."""
        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
        pos_a = data['agent']['token_pos'].clone()
        head_a = data['agent']['token_heading'].clone()
        num_agent, num_step, traj_dim = pos_a.shape
        # Erase the future so the model cannot peek past the history boundary.
        pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)

        agent_valid_mask = data['agent']['agent_valid_mask'].clone()
        agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
        agent_valid_mask[~eval_mask] = False
        agent_token_index = data['agent']['token_idx']
        agent_category = data['agent']['category']
        feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(
            data, agent_category, agent_token_index, pos_a, head_vector_a, inference=True)

        agent_type = data["agent"]["type"]
        veh_mask = (agent_type == 0)  # * agent_category==3
        cyc_mask = (agent_type == 2)  # * agent_category==3
        ped_mask = (agent_type == 1)  # * agent_category==3
        av_mask = data["agent"]["av_index"]

        self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps
        pred_traj = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device)
        pred_head = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, device=feat_a.device)
        pred_prob = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val // self.shift,
                                device=feat_a.device)
        next_token_idx_list = []
        mask = agent_valid_mask.clone()
        feat_a_t_dict = {}  # per-layer feature cache so past steps are not recomputed
        for t in range(self.num_recurrent_steps_val // self.shift):
            if t == 0:
                # First step: attend over the full history.
                inference_mask = mask.clone()
                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
            else:
                # Later steps: only the most recently generated step is a query target.
                inference_mask = torch.zeros_like(mask)
                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
            edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask,
                                                         inference_mask)
            # NOTE: the comprehension variable shadows the loop's `t` but does not leak
            # out of the comprehension in Python 3.
            if isinstance(data, Batch):
                batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
                                     for t in range(num_step)], dim=0)
                batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
                                      for t in range(num_step)], dim=0)
            else:
                batch_s = torch.arange(num_step,
                                       device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
                batch_pl = torch.arange(num_step,
                                        device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
            # In the inference stage, we only infer the current stage for recurrent
            edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
                                                                head_vector_a, inference_mask, batch_s, batch_pl)
            mask_s = inference_mask.transpose(0, 1).reshape(-1)
            edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s)

            for i in range(self.num_layers):
                if i in feat_a_t_dict:
                    feat_a = feat_a_t_dict[i]
                feat_a = feat_a.reshape(-1, self.hidden_dim)
                feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
                feat_a = feat_a.reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
                feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
                    repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                    -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
                feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
                feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
                # Cache layer output; on later iterations only refresh the current step.
                if i + 1 not in feat_a_t_dict:
                    feat_a_t_dict[i + 1] = feat_a
                else:
                    feat_a_t_dict[i + 1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]

            # Predict and sample the next token from the top beam_size candidates.
            next_token_prob = self.token_predict_head(
                feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
            topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1)

            # NOTE(review): the literal 6 below assumes self.shift + 1 == 6 — confirm.
            expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)
            next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index)

            # Rotate candidate token trajectories into the world frame at the current pose.
            theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
                                       rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
                                           -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
            agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]

            # Sample one candidate per agent proportionally to its top-k probability.
            sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device)
            agent_pred_rel = agent_pred_rel.gather(dim=1,
                                                   index=sample_index[..., None, None, None].expand(-1, -1, 6, 4,
                                                                                                    2))[:, 0, ...]
            pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0]
            # Center of the 4-corner contour is the position; corners 0/3 give heading.
            pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
            diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
            pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])

            # Commit the final frame as the new current pose for the next rollout step.
            pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
            diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
            head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
            next_token_idx = next_token_idx.gather(dim=1, index=sample_index)
            next_token_idx = next_token_idx.squeeze(-1)
            next_token_idx_list.append(next_token_idx[:, None])

            # Write the sampled token's embedding into the current step, per class.
            agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
                next_token_idx[veh_mask]]
            agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
                next_token_idx[ped_mask]]
            agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
                next_token_idx[cyc_mask]]

            # Rebuild the fused features from the updated pose/token state.
            motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
                                         pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
            head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
            vel = motion_vector_a.clone() / (0.1 * self.shift)  # 0.1 s per raw step
            vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
            motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
            x_a = torch.stack(
                [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)

            x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
                               categorical_embs=categorical_embs)
            x_a = x_a.view(-1, num_step, self.hidden_dim)
            feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
            feat_a = self.fusion_emb(feat_a)

        agent_valid_mask[agent_category != 3] = False

        return {
            'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
            'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
            'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(),
            'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
            'pred_traj': pred_traj,
            'pred_head': pred_head,
            'next_token_idx': torch.cat(next_token_idx_list, dim=-1),
            'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1),
            'next_token_eval_mask': data['agent']['agent_valid_mask'],
            'pred_prob': pred_prob,
            'vel': vel
        }
class SMARTMapDecoder(nn.Module):
    """Encodes map tokens with relation-aware self-attention and predicts the next
    map token for the masked-token pretext task."""

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token) -> None:
        super(SMARTMapDecoder, self).__init__()
        self.dataset = dataset
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.pl2pl_radius = pl2pl_radius
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout
        # Relative pt2pt features: distance, bearing, orientation delta (+ dz in 3D).
        if input_dim == 2:
            input_dim_r_pt2pt = 3
        elif input_dim == 3:
            input_dim_r_pt2pt = 4
        else:
            raise ValueError('{} is not a valid dimension'.format(input_dim))

        self.type_pt_emb = nn.Embedding(17, hidden_dim)
        self.side_pt_emb = nn.Embedding(4, hidden_dim)
        self.polygon_type_emb = nn.Embedding(4, hidden_dim)
        self.light_pl_emb = nn.Embedding(4, hidden_dim)
        self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
                                            num_freq_bands=num_freq_bands)
        self.pt2pt_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        # Map-token vocabulary size is fixed by the pretrained token file.
        self.token_size = 1024
        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=self.token_size)
        input_dim_token = 22  # flattened map-token sample dimension
        self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.map_token = map_token
        self.apply(weight_init)
        self.mask_pt = False

    def maybe_autocast(self, dtype=torch.float32):
        # Convenience wrapper; callers use it as a context manager.
        return torch.cuda.amp.autocast(dtype=dtype)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Encode map tokens and emit next-token predictions for the masked positions."""
        pt_valid_mask = data['pt_token']['pt_valid_mask']
        pt_pred_mask = data['pt_token']['pt_pred_mask']
        pt_target_mask = data['pt_token']['pt_target_mask']
        mask_s = pt_valid_mask

        pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pt = data['pt_token']['orientation'].contiguous()
        orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
        # Embed the whole token vocabulary once, then index by each node's token id.
        token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
        pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
        pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]

        if self.input_dim == 2:
            x_pt = pt_token_emb
        elif self.input_dim == 3:
            x_pt = pt_token_emb
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))

        # Add categorical context: point type, polygon type and traffic-light state.
        token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
        token_light_type = data['map_polygon']['light_type'][token2pl[1]]
        x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
                                 self.polygon_type_emb(data['pt_token']['pl_type'].long()),
                                 self.light_pl_emb(token_light_type.long()),]
        x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
        edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
                                        batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
                                        loop=False, max_num_neighbors=100)
        if self.mask_pt:
            edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
        rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
        rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
        if self.input_dim == 2:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_orient_pt2pt], dim=-1)
        elif self.input_dim == 3:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_pos_pt2pt[:, -1],
                 rel_orient_pt2pt], dim=-1)
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))
        r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
        for i in range(self.num_layers):
            x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)
        next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
        next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]
        return {
            'x_pt': x_pt,
            'map_next_token_idx': next_token_idx,
            'map_next_token_prob': next_token_prob,
            'map_next_token_idx_gt': next_token_index_gt,
            # All-True mask with the same length as the predictions.
            'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
        }


class SMARTDecoder(nn.Module):
    """Top-level decoder wiring the map token encoder and the agent token decoder."""

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 time_span: Optional[int],
                 pl2a_radius: float,
                 a2a_radius: float,
                 num_freq_bands: int,
                 num_map_layers: int,
                 num_agent_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token: Dict,
                 token_data: Dict,
                 use_intention=False,
                 token_size=512) -> None:
        super(SMARTDecoder, self).__init__()
        self.map_encoder = SMARTMapDecoder(
            dataset=dataset,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_historical_steps=num_historical_steps,
            pl2pl_radius=pl2pl_radius,
            num_freq_bands=num_freq_bands,
            num_layers=num_map_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            map_token=map_token
        )
        self.agent_encoder = SMARTAgentDecoder(
            dataset=dataset,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_historical_steps=num_historical_steps,
            time_span=time_span,
            pl2a_radius=pl2a_radius,
            a2a_radius=a2a_radius,
            num_freq_bands=num_freq_bands,
            num_layers=num_agent_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            token_size=token_size,
            token_data=token_data
        )
        self.map_enc = None  # optional cache slot for a precomputed map encoding
HeteroData) -> Dict[str, torch.Tensor]: map_enc = self.map_encoder(data) agent_enc = self.agent_encoder(data, map_enc) return {**map_enc, **agent_enc} def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]: map_enc = self.map_encoder(data) agent_enc = self.agent_encoder.inference(data, map_enc) return {**map_enc, **agent_enc} def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]: agent_enc = self.agent_encoder.inference(data, map_enc) return {**map_enc, **agent_enc} ================================================ FILE: smart/preprocess/__init__.py ================================================ ================================================ FILE: smart/preprocess/preprocess.py ================================================ import numpy as np import pandas as pd import os import torch from typing import Any, Dict, List, Optional predict_unseen_agents = False vector_repr = True _agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background'] _polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN'] _polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN'] _point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW', 'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE', 'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE', 'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE'] _point_sides = ['LEFT', 'RIGHT', 'CENTER'] _polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT'] _polygon_is_intersections = [True, False, None] Lane_type_hash = { 4: "BIKE", 3: "VEHICLE", 2: "VEHICLE", 1: "BUS" } boundary_type_hash = { 5: "UNKNOWN", 6: "DASHED_WHITE", 7: "SOLID_WHITE", 8: "DOUBLE_DASH_WHITE", 9: "DASHED_YELLOW", 10: "DOUBLE_DASH_YELLOW", 11: "SOLID_YELLOW", 12: "DOUBLE_SOLID_YELLOW", 13: "DASH_SOLID_YELLOW", 14: "UNKNOWN", 15: "EDGE", 16: "EDGE" } def get_agent_features(df: pd.DataFrame, av_id, 
def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
    """Convert a per-scenario dataframe of agent rows into dense, zero-padded tensors.

    Args:
        df: rows with columns track_id, timestep, object_type, object_category,
            position_x/y/z, heading, velocity_x/y, length, width, height.
        av_id: track id of the ego vehicle; raises ValueError if absent.
        num_historical_steps: number of observed (history) steps.
        dim: spatial dimensionality of position/velocity/shape tensors.
        num_steps: total scenario length in steps.

    Returns:
        Dict of per-agent tensors keyed as consumed by the dataset pipeline.
    """
    if not predict_unseen_agents:  # filter out agents that are unseen during the historical time steps
        historical_df = df[df['timestep'] == num_historical_steps - 1]
        agent_ids = list(historical_df['track_id'].unique())
        df = df[df['track_id'].isin(agent_ids)]
    else:
        agent_ids = list(df['track_id'].unique())

    num_agents = len(agent_ids)
    # O(1) id -> index table; the original used list.index() inside the groupby
    # loop, which is O(n) per track and O(n^2) overall.
    agent_idx_by_id = {tid: idx for idx, tid in enumerate(agent_ids)}

    # initialization
    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    agent_id: List[Optional[str]] = [None] * num_agents
    agent_type = torch.zeros(num_agents, dtype=torch.uint8)
    agent_category = torch.zeros(num_agents, dtype=torch.uint8)
    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)

    for track_id, track_df in df.groupby('track_id'):
        agent_idx = agent_idx_by_id[track_id]
        agent_steps = track_df['timestep'].values

        valid_mask[agent_idx, agent_steps] = True
        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
        predict_mask[agent_idx, agent_steps] = True
        if vector_repr:  # a time step t is valid only when both t and t-1 are valid
            valid_mask[agent_idx, 1: num_historical_steps] = (
                valid_mask[agent_idx, :num_historical_steps - 1] &
                valid_mask[agent_idx, 1: num_historical_steps])
            valid_mask[agent_idx, 0] = False
        predict_mask[agent_idx, :num_historical_steps] = False
        if not current_valid_mask[agent_idx]:
            # Agents invisible at the current step are never prediction targets.
            predict_mask[agent_idx, num_historical_steps:] = False

        agent_id[agent_idx] = track_id
        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
        agent_category[agent_idx] = track_df['object_category'].values[0]
        position[agent_idx, agent_steps, :3] = torch.from_numpy(
            np.stack([track_df['position_x'].values,
                      track_df['position_y'].values,
                      track_df['position_z'].values], axis=-1)).float()
        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(
            np.stack([track_df['velocity_x'].values,
                      track_df['velocity_y'].values], axis=-1)).float()
        shape[agent_idx, agent_steps, :3] = torch.from_numpy(
            np.stack([track_df['length'].values,
                      track_df['width'].values,
                      track_df["height"].values], axis=-1)).float()

    # Keep list.index here so a missing av_id still raises ValueError as before.
    av_idx = agent_id.index(av_id)

    return {
        'num_nodes': num_agents,
        'av_index': av_idx,
        'valid_mask': valid_mask,
        'predict_mask': predict_mask,
        'id': agent_id,
        'type': agent_type,
        'category': agent_category,
        'position': position,
        'heading': heading,
        'velocity': velocity,
        'shape': shape
    }


def to_16(data):
    """Recursively downcast float32 tensors in a (possibly nested) dict to float16.

    Dicts are mutated in place and returned; non-float32 leaves pass through unchanged.
    """
    if isinstance(data, dict):
        for key, value in data.items():
            new_value = to_16(value)
            data[key] = new_value
    if isinstance(data, torch.Tensor):
        if data.dtype == torch.float32:
            data = data.to(torch.float16)
    return data


def tofloat32(data):
    """Recursively downcast float64 tensors in a (possibly nested) dict to float32.

    The dict is mutated in place and returned; other values pass through unchanged.
    """
    for name in data:
        value = data[name]
        if isinstance(value, dict):
            value = tofloat32(value)
        elif isinstance(value, torch.Tensor) and value.dtype == torch.float64:
            value = value.to(torch.float32)
        data[name] = value
    return data
class WaymoTargetBuilder(BaseTransform):
    """Transform that scores/filters agents in a Waymo scenario dict and wraps the
    result in a HeteroData graph. All helpers mutate the `agent` dict in place."""

    def __init__(self,
                 num_historical_steps: int,
                 num_future_steps: int,
                 mode="train") -> None:
        self.num_historical_steps = num_historical_steps
        self.num_future_steps = num_future_steps
        self.mode = mode
        self.num_features = 3
        self.augment = False
        self.logger = Logging().log(level='DEBUG')

    def score_ego_agent(self, agent):
        # Mark the ego vehicle with its own category code (5).
        av_index = agent['av_index']
        agent["category"][av_index] = 5
        return agent

    def clip(self, agent, max_num=32):
        """Keep only the `max_num` agents closest to the ego at the last history step
        (obstacles pushed to the back), re-indexing all per-agent fields."""
        av_index = agent["av_index"]
        valid = agent['valid_mask']
        ego_pos = agent["position"][av_index]
        obstacle_mask = agent['type'] == 3
        distance = torch.norm(agent["position"][:, self.num_historical_steps - 1, :2] -
                              ego_pos[self.num_historical_steps - 1, :2], dim=-1)
        # Push background obstacles to the end of the sort so they are dropped first.
        distance[obstacle_mask] = 10e5
        sort_idx = distance.sort()[1]
        mask = torch.zeros(valid.shape[0])
        mask[sort_idx[:max_num]] = 1
        mask = mask.to(torch.bool)
        mask[av_index] = True
        # Ego's new row index = number of kept agents that precede it.
        new_av_index = mask[:av_index].sum()
        agent["num_nodes"] = int(mask.sum())
        agent["av_index"] = int(new_av_index)
        excluded = ["num_nodes", "av_index", "ego"]
        for key, val in agent.items():
            if key in excluded:
                continue
            if key == "id":
                # `id` is a Python list; filter it via numpy boolean indexing.
                val = list(np.array(val)[mask])
                agent[key] = val
                continue
            if len(val.size()) > 1:
                agent[key] = val[mask, ...]
            else:
                agent[key] = val[mask]
        return agent

    def score_nearby_vehicle(self, agent, max_num=10):
        """Mark up to max_num-1 nearest non-obstacle agents (excluding ego itself,
        slot 0 of the sort) as prediction targets (category 3)."""
        av_index = agent['av_index']
        agent["category"] = torch.zeros_like(agent["category"])
        obstacle_mask = agent['type'] == 3
        pos = agent["position"][av_index, self.num_historical_steps, :2]
        distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
        distance[obstacle_mask] = 10e5
        sort_idx = distance.sort()[1]
        nearby_mask = torch.zeros(distance.shape[0])
        nearby_mask[sort_idx[1:max_num]] = 1
        nearby_mask = nearby_mask.bool()
        agent["category"][nearby_mask] = 3
        agent["category"][obstacle_mask] = 0

    def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
        """Select up to max_num non-background agents near the ego as training
        targets (category 3), also invalidating far-away observations."""
        av_index = agent['av_index']
        agent["category"] = torch.zeros_like(agent["category"])
        pos = agent["position"][av_index, self.num_historical_steps, :2]
        distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
        distance_all_time = torch.norm(agent["position"][:, :, :2] - agent["position"][av_index, :, :2],
                                       dim=-1)
        # NOTE(review): despite the name, this is True for *kept* (in-range) steps.
        invalid_mask = distance_all_time < 150
        agent["valid_mask"] = agent["valid_mask"] * invalid_mask
        # we do not believe the perception out of range of 150 meters
        # we do not predict vehicles too far away from the ego car
        closet_vehicle = distance < 100
        valid = agent['valid_mask']
        valid_current = valid[:, (self.num_historical_steps):]
        valid_counts = valid_current.sum(1)
        counts_vehicle = valid_counts >= 1
        no_backgroud = agent['type'] != 3
        vehicle2pred = closet_vehicle & counts_vehicle & no_backgroud
        if vehicle2pred.sum() > max_num:
            # Too many candidates: randomly subsample max_num of them.
            true_indices = torch.nonzero(vehicle2pred).squeeze(1)
            selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]
            vehicle2pred.fill_(False)
            vehicle2pred[selected_indices] = True
        agent["category"][vehicle2pred] = 3

    def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):
        """Express past/future positions and headings in each agent's local frame
        anchored at the last historical step. Returns (his, target), each
        [num_nodes, steps, 4] with layout [x, y, (z,) dtheta]."""
        origin = position[:, num_historical_steps - 1]
        theta = heading[:, num_historical_steps - 1]
        cos, sin = theta.cos(), theta.sin()
        rot_mat = theta.new_zeros(num_nodes, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = -sin
        rot_mat[:, 1, 0] = sin
        rot_mat[:, 1, 1] = cos
        target = origin.new_zeros(num_nodes, num_future_steps, 4)
        target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -
                                    origin[:, :2].unsqueeze(1), rot_mat)
        his = origin.new_zeros(num_nodes, num_historical_steps, 4)
        his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -
                                 origin[:, :2].unsqueeze(1), rot_mat)
        if position.size(2) == 3:
            # 3D input: channel 2 is relative z, channel 3 the wrapped heading delta.
            target[..., 2] = (position[:, num_historical_steps:, 2] - origin[:, 2].unsqueeze(-1))
            his[..., 2] = (position[:, :num_historical_steps, 2] - origin[:, 2].unsqueeze(-1))
            target[..., 3] = wrap_angle(heading[:, num_historical_steps:] - theta.unsqueeze(-1))
            his[..., 3] = wrap_angle(heading[:, :num_historical_steps] - theta.unsqueeze(-1))
        else:
            # 2D input: channel 2 carries the heading delta; channel 3 stays zero.
            target[..., 2] = wrap_angle(heading[:, num_historical_steps:] - theta.unsqueeze(-1))
            his[..., 2] = wrap_angle(heading[:, :num_historical_steps] - theta.unsqueeze(-1))
        return his, target

    def __call__(self, data) -> HeteroData:
        # Score ego and select training targets in place, then wrap as a graph.
        agent = data["agent"]
        self.score_ego_agent(agent)
        self.score_trained_vehicle(agent, max_num=32)
        return HeteroData(data)
smart/utils/cluster_reader.py ================================================ import io import pickle import pandas as pd import json class LoadScenarioFromCeph: def __init__(self): from petrel_client.client import Client self.file_client = Client('~/petreloss.conf') def list(self, dir_path): return list(self.file_client.list(dir_path)) def save(self, data, url): self.file_client.put(url, pickle.dumps(data)) def read_correct_csv(self, scenario_path): output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python") return output def contains(self, url): return self.file_client.contains(url) def read_string(self, csv_url): from io import StringIO df = pd.read_csv(StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep='\s+', low_memory=False) return df def read(self, scenario_path): with io.BytesIO(self.file_client.get(scenario_path)) as f: datas = pickle.load(f) return datas def read_json(self, path): with io.BytesIO(self.file_client.get(path)) as f: data = json.load(f) return data def read_csv(self, scenario_path): return pickle.loads(self.file_client.get(scenario_path)) def read_model(self, model_path): with io.BytesIO(self.file_client.get(model_path)) as f: pass ================================================ FILE: smart/utils/config.py ================================================ import os import yaml import easydict def load_config_act(path): """ load config file""" with open(path, 'r') as f: cfg = yaml.load(f, Loader=yaml.FullLoader) return easydict.EasyDict(cfg) def load_config_init(path): """ load config file""" path = os.path.join('init/configs', f'{path}.yaml') with open(path, 'r') as f: cfg = yaml.load(f, Loader=yaml.FullLoader) return cfg ================================================ FILE: smart/utils/geometry.py ================================================ import math import torch def angle_between_2d_vectors( ctr_vector: torch.Tensor, nbr_vector: torch.Tensor) -> torch.Tensor: return 
torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0], (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1)) def angle_between_3d_vectors( ctr_vector: torch.Tensor, nbr_vector: torch.Tensor) -> torch.Tensor: return torch.atan2(torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1), (ctr_vector * nbr_vector).sum(dim=-1)) def side_to_directed_lineseg( query_point: torch.Tensor, start_point: torch.Tensor, end_point: torch.Tensor) -> str: cond = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) - (end_point[1] - start_point[1]) * (query_point[0] - start_point[0])) if cond > 0: return 'LEFT' elif cond < 0: return 'RIGHT' else: return 'CENTER' def wrap_angle( angle: torch.Tensor, min_val: float = -math.pi, max_val: float = math.pi) -> torch.Tensor: return min_val + (angle + max_val) % (max_val - min_val) ================================================ FILE: smart/utils/graph.py ================================================ from typing import List, Optional, Tuple, Union import torch from torch_geometric.utils import coalesce from torch_geometric.utils import degree def add_edges( from_edge_index: torch.Tensor, to_edge_index: torch.Tensor, from_edge_attr: Optional[torch.Tensor] = None, to_edge_attr: Optional[torch.Tensor] = None, replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype) mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) & (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0))) if replace: to_mask = mask.any(dim=1) if from_edge_attr is not None and to_edge_attr is not None: from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0) to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1) else: from_mask = 
mask.any(dim=0) if from_edge_attr is not None and to_edge_attr is not None: from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0) to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1) return to_edge_index, to_edge_attr def merge_edges( edge_indices: List[torch.Tensor], edge_attrs: Optional[List[torch.Tensor]] = None, reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]: edge_index = torch.cat(edge_indices, dim=1) if edge_attrs is not None: edge_attr = torch.cat(edge_attrs, dim=0) else: edge_attr = None return coalesce(edge_index=edge_index, edge_attr=edge_attr, reduce=reduce) def complete_graph( num_nodes: Union[int, Tuple[int, int]], ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, loop: bool = False, device: Optional[Union[torch.device, str]] = None) -> torch.Tensor: if ptr is None: if isinstance(num_nodes, int): num_src, num_dst = num_nodes, num_nodes else: num_src, num_dst = num_nodes edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device), torch.arange(num_dst, dtype=torch.long, device=device)).t() else: if isinstance(ptr, torch.Tensor): ptr_src, ptr_dst = ptr, ptr num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1] else: ptr_src, ptr_dst = ptr num_src_batch = ptr_src[1:] - ptr_src[:-1] num_dst_batch = ptr_dst[1:] - ptr_dst[:-1] edge_index = torch.cat( [torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device), torch.arange(num_dst, dtype=torch.long, device=device)) + p for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))], dim=0) edge_index = edge_index.t() if isinstance(num_nodes, int) and not loop: edge_index = edge_index[:, edge_index[0] != edge_index[1]] return edge_index.contiguous() def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor: index = 
adj.nonzero(as_tuple=True) if len(index) == 3: batch_src = index[0] * adj.size(1) batch_dst = index[0] * adj.size(2) index = (batch_src + index[1], batch_dst + index[2]) return torch.stack(index, dim=0) def unbatch( src: torch.Tensor, batch: torch.Tensor, dim: int = 0) -> List[torch.Tensor]: sizes = degree(batch, dtype=torch.long).tolist() return src.split(sizes, dim) ================================================ FILE: smart/utils/list.py ================================================ from typing import Any, List, Optional def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]: try: return ls.index(elem) except ValueError: return None ================================================ FILE: smart/utils/log.py ================================================ import logging import time import os class Logging: def make_log_dir(self, dirname='logs'): now_dir = os.path.dirname(__file__) path = os.path.join(now_dir, dirname) path = os.path.normpath(path) if not os.path.exists(path): os.mkdir(path) return path def get_log_filename(self): filename = "{}.log".format(time.strftime("%Y-%m-%d",time.localtime())) filename = os.path.join(self.make_log_dir(), filename) filename = os.path.normpath(filename) return filename def log(self, level='DEBUG', name="simagent"): logger = logging.getLogger(name) level = getattr(logging, level) logger.setLevel(level) if not logger.handlers: sh = logging.StreamHandler() fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8") fmt = logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s") sh.setFormatter(fmt=fmt) fh.setFormatter(fmt=fmt) logger.addHandler(sh) logger.addHandler(fh) return logger def add_log(self, logger, level='DEBUG'): level = getattr(logging, level) logger.setLevel(level) if not logger.handlers: sh = logging.StreamHandler() fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding="utf-8") fmt = 
logging.Formatter("%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s") sh.setFormatter(fmt=fmt) fh.setFormatter(fmt=fmt) logger.addHandler(sh) logger.addHandler(fh) return logger if __name__ == '__main__': logger = Logging().log(level='INFO') logger.debug("1111111111111111111111") #使用日志器生成日志 logger.info("222222222222222222222222") logger.error("附件为IP飞机外婆家二分IP文件放") logger.warning("3333333333333333333333333333") logger.critical("44444444444444444444444444") ================================================ FILE: smart/utils/nan_checker.py ================================================ import torch def check_nan_inf(t, s): assert not torch.isinf(t).any(), f"{s} is inf, {t}" assert not torch.isnan(t).any(), f"{s} is nan, {t}" ================================================ FILE: smart/utils/weight_init.py ================================================ import torch.nn as nn def weight_init(m: nn.Module) -> None: if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): fan_in = m.in_channels / m.groups fan_out = m.out_channels / m.groups bound = (6.0 / (fan_in + fan_out)) ** 0.5 nn.init.uniform_(m.weight, -bound, bound) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.MultiheadAttention): if m.in_proj_weight is not None: fan_in = m.embed_dim fan_out = m.embed_dim bound = (6.0 / (fan_in + fan_out)) ** 0.5 nn.init.uniform_(m.in_proj_weight, -bound, bound) else: nn.init.xavier_uniform_(m.q_proj_weight) nn.init.xavier_uniform_(m.k_proj_weight) nn.init.xavier_uniform_(m.v_proj_weight) if m.in_proj_bias is not None: nn.init.zeros_(m.in_proj_bias) 
nn.init.xavier_uniform_(m.out_proj.weight) if m.out_proj.bias is not None: nn.init.zeros_(m.out_proj.bias) if m.bias_k is not None: nn.init.normal_(m.bias_k, mean=0.0, std=0.02) if m.bias_v is not None: nn.init.normal_(m.bias_v, mean=0.0, std=0.02) elif isinstance(m, (nn.LSTM, nn.LSTMCell)): for name, param in m.named_parameters(): if 'weight_ih' in name: for ih in param.chunk(4, 0): nn.init.xavier_uniform_(ih) elif 'weight_hh' in name: for hh in param.chunk(4, 0): nn.init.orthogonal_(hh) elif 'weight_hr' in name: nn.init.xavier_uniform_(param) elif 'bias_ih' in name: nn.init.zeros_(param) elif 'bias_hh' in name: nn.init.zeros_(param) nn.init.ones_(param.chunk(4, 0)[1]) elif isinstance(m, (nn.GRU, nn.GRUCell)): for name, param in m.named_parameters(): if 'weight_ih' in name: for ih in param.chunk(3, 0): nn.init.xavier_uniform_(ih) elif 'weight_hh' in name: for hh in param.chunk(3, 0): nn.init.orthogonal_(hh) elif 'bias_ih' in name: nn.init.zeros_(param) elif 'bias_hh' in name: nn.init.zeros_(param) ================================================ FILE: train.py ================================================ from argparse import ArgumentParser import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.strategies import DDPStrategy from smart.utils.config import load_config_act from smart.datamodules import MultiDataModule from smart.model import SMART from smart.utils.log import Logging if __name__ == '__main__': parser = ArgumentParser() Predictor_hash = {"smart": SMART, } parser.add_argument('--config', type=str, default='configs/train/train_scalable.yaml') parser.add_argument('--pretrain_ckpt', type=str, default="") parser.add_argument('--ckpt_path', type=str, default="") parser.add_argument('--save_ckpt_path', type=str, default="") args = parser.parse_args() config = load_config_act(args.config) Predictor = Predictor_hash[config.Model.predictor] 
strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True) Data_config = config.Dataset datamodule = MultiDataModule(**vars(Data_config)) if args.pretrain_ckpt == "": model = Predictor(config.Model) else: logger = Logging().log(level='DEBUG') model = Predictor(config.Model) model.load_params_from_file(filename=args.pretrain_ckpt, logger=logger) trainer_config = config.Trainer model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path, filename="{epoch:02d}", monitor='val_cls_acc', every_n_epochs=1, save_top_k=5, mode='max') lr_monitor = LearningRateMonitor(logging_interval='epoch') trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=trainer_config.devices, strategy=strategy, accumulate_grad_batches=trainer_config.accumulate_grad_batches, num_nodes=trainer_config.num_nodes, callbacks=[model_checkpoint, lr_monitor], max_epochs=trainer_config.max_epochs, num_sanity_val_steps=0, gradient_clip_val=0.5) if args.ckpt_path == "": trainer.fit(model, datamodule) else: trainer.fit(model, datamodule, ckpt_path=args.ckpt_path) ================================================ FILE: val.py ================================================ from argparse import ArgumentParser import pytorch_lightning as pl from torch_geometric.loader import DataLoader from smart.datasets.scalable_dataset import MultiDataset from smart.model import SMART from smart.transforms import WaymoTargetBuilder from smart.utils.config import load_config_act from smart.utils.log import Logging if __name__ == '__main__': pl.seed_everything(2, workers=True) parser = ArgumentParser() parser.add_argument('--config', type=str, default="configs/validation/validation_scalable.yaml") parser.add_argument('--pretrain_ckpt', type=str, default="") parser.add_argument('--ckpt_path', type=str, default="") parser.add_argument('--save_ckpt_path', type=str, default="") args = parser.parse_args() config = load_config_act(args.config) data_config = config.Dataset val_dataset = { 
"scalable": MultiDataset, }[data_config.dataset](root=data_config.root, split='val', raw_dir=data_config.val_raw_dir, processed_dir=data_config.val_processed_dir, transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps)) dataloader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers, pin_memory=data_config.pin_memory, persistent_workers=True if data_config.num_workers > 0 else False) Predictor = SMART if args.pretrain_ckpt == "": model = Predictor(config.Model) else: logger = Logging().log(level='DEBUG') model = Predictor(config.Model) model.load_params_from_file(filename=args.pretrain_ckpt, logger=logger) trainer_config = config.Trainer trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=trainer_config.devices, strategy='ddp', num_sanity_val_steps=0) trainer.validate(model, dataloader)