[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n.github\nckpt/\n# assets/\n# C extensions\n*.so\n# /assets\n/data\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\n*.jpg\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n*.jpg\npyg_depend/\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# IDEs\n.idea\n.vscode\n\n# seed project\nav2/\nlightning_logs/\nlightning_logs_/\nlightning_l/\n.DS_Store\ndata/argo\ndata/res\ndata/waymo*\nfig*/\ndata/waymo_token\ndata/submission\ndata/token_seq_emb_nuplan\ndata/token_seq_emb_waymo\ndata/nuplan*\nsubmission.tar.gz\ndata/feat*\ndata/scalable\ndata/pos_data\nres_metrics*\ngathered*"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n  \n  # SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction\n  \n  [Paper](https://arxiv.org/abs/2405.15677) | [Webpage](https://smart-motion.github.io/smart/)\n\n</div>\n\n- **Ranked 1st** on the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/)  \n- **Champion** of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)\n\n## News\n- **[December 31, 2024]** SMART-Planner achieved state-of-the-art performance on **nuPlan closed-loop planning**\n- **[September 26, 2024]** SMART was **accepted to** NeurIPS 2024\n- **[August 31, 2024]** Code released\n- **[May 24, 2024]** SMART won the championship of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)\n- **[May 24, 2024]** SMART paper released on [arxiv](https://arxiv.org/abs/2405.15677)\n\n\n## Introduction\nThis repository contains the official implementation of SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction. SMART is a novel autonomous driving motion generation paradigm that models vectorized map and agent trajectory data into discrete sequence tokens.\n\nhttps://github.com/user-attachments/assets/74a61627-8444-4e54-bb10-d317dd2aacd9\n\n## Requirements\n\nTo set up the environment, you can use conda to create and activate a new environment with the necessary dependencies:\n\n```bash\nconda env create -f environment.yml\nconda activate SMART\npip install -r requirements.txt\n```\n\nIf you encounter issues while installing pyg dependencies, execute the following script:\n```setup\nbash install_pyg.sh\n```\n\nAlternatively, you can configure the environment in your preferred way. 
Installing the latest versions of PyTorch, PyG, and PyTorch Lightning should suffice.\n\n## Data installation\n\n**Step 1: Download the Dataset**\n\nDownload the Waymo Open Motion Dataset (`scenario protocol` format) and organize the data as follows:\n```\nSMART\n├── data\n│   ├── waymo\n│   │   ├── scenario\n│   │   │   ├──training\n│   │   │   ├──validation\n│   │   │   ├──testing\n├── model\n├── tools\n```\n\n**Step 2: Install the Waymo Open Dataset API**\n\nFollow the instructions [here](https://github.com/waymo-research/waymo-open-dataset) to install the Waymo Open Dataset API.\n\n**Step 3: Preprocess the Dataset**\n\nPreprocess the dataset by running:\n```\npython data_preprocess.py --input_dir ./data/waymo/scenario/training  --output_dir ./data/waymo_processed/training\n```\nThe first path is the raw data path, and the second is the output data path.\n\nThe processed data will be saved to the `data/waymo_processed/` directory as follows:\n\n```\nSMART\n├── data\n│   ├── waymo_processed\n│   │   ├── training\n│   │   ├── validation\n│   │   ├──testing\n├── model\n├── utils\n```\n\n## Training\n\nTo train the model, run the following command:\n\n```train\npython train.py --config ${config_path}\n```\n\nThe default config path is `configs/train/train_scalable.yaml`. Ensure you have downloaded and prepared the Waymo data for training.\n\n## Evaluation\n\nTo evaluate the model, run:\n\n```eval\npython eval.py --config ${config_path} --pretrain_ckpt ${ckpt_path}\n```\nThis will evaluate the model using the configuration and checkpoint provided.\n\n\n## Pre-trained Models\n\nTo comply with the WOMD participation agreement, we will release the model parameters of a medium-sized model not trained on Waymo data. 
Users can fine-tune this model with Waymo data as needed.\n\n## Results\n\n### Waymo Open Motion Dataset Sim Agents Challenge\n\nOur model achieves the following performance on the [Waymo Open Motion Dataset Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/):\n\n| Model name    | Metric Score |\n| :-----------: | ------------ |\n| SMART-tiny    | 0.7591       |\n| SMART-large   | 0.7614       |\n| SMART-zeroshot| 0.7210       |\n\n### NuPlan Closed-loop Planning\n\n**SMART-Planner** achieved state-of-the-art performance among learning-based algorithms on **nuPlan closed-loop planning**. The results on val14 are shown below:\n\n![nuPlan Closed-loop Planning](assets/result1.png)\n\n## Citation \n\nIf you find this repository useful, please consider citing our work and giving us a star:\n\n```citation\n@article{wu2024smart,\n  title={SMART: Scalable Multi-agent Real-time Simulation via Next-token Prediction},\n  author={Wu, Wei and Feng, Xiaoxin and Gao, Ziyan and Kan, Yuheng},\n  journal={arXiv preprint arXiv:2405.15677},\n  year={2024}\n}\n```\n\n## Acknowledgements\nSpecial thanks to the [QCNET](https://github.com/ZikangZhou/QCNet) repository for providing valuable reference code that significantly influenced this work. \n\n## License\nAll code in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).\n"
  },
  {
    "path": "__init__.py",
    "content": "\n"
  },
  {
    "path": "configs/train/train_scalable.yaml",
    "content": "# Config format schema number, the yaml support to valid case source from different dataset\ntime_info: &time_info\n  num_historical_steps: 11\n  num_future_steps: 80\n  use_intention: True\n  token_size: 2048\n\nDataset:\n  root:\n  train_batch_size: 1\n  val_batch_size: 1\n  test_batch_size: 1\n  shuffle: True\n  num_workers: 1\n  pin_memory: True\n  persistent_workers: True\n  train_raw_dir: [\"data/valid_demo\"]\n  val_raw_dir: [\"data/valid_demo\"]\n  test_raw_dir:\n  transform: WaymoTargetBuilder\n  train_processed_dir:\n  val_processed_dir:\n  test_processed_dir:\n  dataset: \"scalable\"\n  <<: *time_info\n\nTrainer:\n  strategy: ddp_find_unused_parameters_false\n  accelerator: \"gpu\"\n  devices: 1\n  max_epochs: 32\n  save_ckpt_path:\n  num_nodes: 1\n  mode:\n  ckpt_path:\n  precision: 32\n  accumulate_grad_batches: 1\n\nModel:\n  mode: \"train\"\n  predictor: \"smart\"\n  dataset: \"waymo\"\n  input_dim: 2\n  hidden_dim: 128\n  output_dim: 2\n  output_head: False\n  num_heads: 8\n  <<: *time_info\n  head_dim: 16\n  dropout: 0.1\n  num_freq_bands: 64\n  lr: 0.0005\n  warmup_steps: 0\n  total_steps: 32\n  decoder:\n    <<: *time_info\n    num_map_layers: 3\n    num_agent_layers: 6\n    a2a_radius: 60\n    pl2pl_radius: 10\n    pl2a_radius: 30\n    time_span: 30\n"
  },
  {
    "path": "configs/validation/validation_scalable.yaml",
    "content": "# Config format schema number, the yaml support to valid case source from different dataset\ntime_info: &time_info\n  num_historical_steps: 11\n  num_future_steps: 80\n  token_size: 2048\n\nDataset:\n  root:\n  batch_size: 1\n  shuffle: True\n  num_workers: 1\n  pin_memory: True\n  persistent_workers: True\n  train_raw_dir:\n  val_raw_dir: [\"data/valid_demo\"]\n  test_raw_dir:\n  TargetBuilder: WaymoTargetBuilder\n  train_processed_dir:\n  val_processed_dir:\n  test_processed_dir:\n  dataset: \"scalable\"\n  <<: *time_info\n\nTrainer:\n  strategy: ddp_find_unused_parameters_false\n  accelerator: \"gpu\"\n  devices: 1\n  max_epochs: 32\n  save_ckpt_path: \n  num_nodes: 1\n  mode:\n  ckpt_path: \n  precision: 32\n  accumulate_grad_batches: 1\n\nModel:\n  mode: \"validation\"\n  predictor: \"smart\"\n  dataset: \"waymo\"\n  input_dim: 2\n  hidden_dim: 128\n  output_dim: 2\n  output_head: False\n  num_heads: 8\n  <<: *time_info\n  head_dim: 16\n  dropout: 0.1\n  num_freq_bands: 64\n  lr: 0.0005\n  warmup_steps: 0\n  total_steps: 32\n  decoder:\n    <<: *time_info\n    num_map_layers: 3\n    num_agent_layers: 6\n    a2a_radius: 60\n    pl2pl_radius: 10\n    pl2a_radius: 30\n    time_span: 30\n\n"
  },
  {
    "path": "data_preprocess.py",
    "content": "import numpy as np\nimport pandas as pd\nimport os\nimport torch\nimport pickle\nfrom tqdm import tqdm\nfrom typing import Any, Dict, List, Optional\nimport easydict\n\npredict_unseen_agents = False\nvector_repr = True\nroot = ''\nsplit = 'train'\nraw_dir = os.path.join(root, split, 'raw')\n_raw_dir = raw_dir\n\nif os.path.isdir(_raw_dir):\n    _raw_file_names = [name for name in os.listdir(_raw_dir)]\nelse:\n    _raw_file_names = []\n\nprocessed_dir = os.path.join(root, split, 'processed')\n_processed_dir = processed_dir\nif os.path.isdir(_processed_dir):\n    _processed_file_names = [name for name in os.listdir(_processed_dir) if\n                             name.endswith(('pkl', 'pickle'))]\nelse:\n    _processed_file_names = []\n\n_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']\n_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']\n_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']\n_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',\n                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',\n                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',\n                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']\n_point_sides = ['LEFT', 'RIGHT', 'CENTER']\n_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']\n_polygon_is_intersections = [True, False, None]\n\n\nLane_type_hash = {\n    4: \"BIKE\",\n    3: \"VEHICLE\",\n    2: \"VEHICLE\",\n    1: \"BUS\"\n}\n\nboundary_type_hash = {\n        5: \"UNKNOWN\",\n        6: \"DASHED_WHITE\",\n        7: \"SOLID_WHITE\",\n        8: \"DOUBLE_DASH_WHITE\",\n        9: \"DASHED_YELLOW\",\n        10: \"DOUBLE_DASH_YELLOW\",\n        11: \"SOLID_YELLOW\",\n        12: \"DOUBLE_SOLID_YELLOW\",\n        13: \"DASH_SOLID_YELLOW\",\n        14: \"UNKNOWN\",\n        15: \"EDGE\",\n        16: 
\"EDGE\"\n}\n\n\ndef safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:\n    try:\n        return ls.index(elem)\n    except ValueError:\n        return None\n\n\ndef get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:\n    if not predict_unseen_agents:  # filter out agents that are unseen during the historical time steps\n        historical_df = df[df['timestep'] == num_historical_steps-1]\n        agent_ids = list(historical_df['track_id'].unique())\n        df = df[df['track_id'].isin(agent_ids)]\n    else:\n        agent_ids = list(df['track_id'].unique())\n\n    num_agents = len(agent_ids)\n    # initialization\n    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)\n    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)\n    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)\n    agent_id: List[Optional[str]] = [None] * num_agents\n    agent_type = torch.zeros(num_agents, dtype=torch.uint8)\n    agent_category = torch.zeros(num_agents, dtype=torch.uint8)\n    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)\n    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)\n    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)\n    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)\n\n    for track_id, track_df in df.groupby('track_id'):\n        agent_idx = agent_ids.index(track_id)\n        agent_steps = track_df['timestep'].values\n\n        valid_mask[agent_idx, agent_steps] = True\n        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]\n        predict_mask[agent_idx, agent_steps] = True\n        if vector_repr:  # a time step t is valid only when both t and t-1 are valid\n            valid_mask[agent_idx, 1: num_historical_steps] = (\n                valid_mask[agent_idx, :num_historical_steps - 1] &\n                valid_mask[agent_idx, 1: 
num_historical_steps])\n            valid_mask[agent_idx, 0] = False\n        predict_mask[agent_idx, :num_historical_steps] = False\n        if not current_valid_mask[agent_idx]:\n            predict_mask[agent_idx, num_historical_steps:] = False\n\n        agent_id[agent_idx] = track_id\n        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])\n        agent_category[agent_idx] = track_df['object_category'].values[0]\n        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,\n                                                                          track_df['position_y'].values,\n                                                                          track_df['position_z'].values],\n                                                                         axis=-1)).float()\n        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()\n        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,\n                                                                          track_df['velocity_y'].values],\n                                                                         axis=-1)).float()\n        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,\n                                                                       track_df['width'].values,\n                                                                       track_df[\"height\"].values],\n                                                                      axis=-1)).float()\n    av_idx = agent_id.index(av_id)\n    if split == 'test':\n        predict_mask[current_valid_mask\n                     | (agent_category == 2)\n                     | (agent_category == 3), num_historical_steps:] = True\n\n    return {\n        'num_nodes': num_agents,\n        'av_index': av_idx,\n        'valid_mask': valid_mask,\n        
'predict_mask': predict_mask,\n        'id': agent_id,\n        'type': agent_type,\n        'category': agent_category,\n        'position': position,\n        'heading': heading,\n        'velocity': velocity,\n        'shape': shape\n    }\n\n\ndef get_map_features(map_infos, tf_current_light, dim=3):\n    lane_segments = map_infos['lane']\n    all_polylines = map_infos[\"all_polylines\"]\n    crosswalks = map_infos['crosswalk']\n    road_edges = map_infos['road_edge']\n    road_lines = map_infos['road_line']\n    lane_segment_ids = [info[\"id\"] for info in lane_segments]\n    cross_walk_ids = [info[\"id\"] for info in crosswalks]\n    road_edge_ids = [info[\"id\"] for info in road_edges]\n    road_line_ids = [info[\"id\"] for info in road_lines]\n    polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids\n    num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids)\n\n    # initialization\n    polygon_type = torch.zeros(num_polygons, dtype=torch.uint8)\n    polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3\n\n    point_position: List[Optional[torch.Tensor]] = [None] * num_polygons\n    point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons\n    point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons\n    point_height: List[Optional[torch.Tensor]] = [None] * num_polygons\n    point_type: List[Optional[torch.Tensor]] = [None] * num_polygons\n\n    for lane_segment in lane_segments:\n        lane_segment = easydict.EasyDict(lane_segment)\n        lane_segment_idx = polygon_ids.index(lane_segment.id)\n        polyline_index = lane_segment.polyline_index\n        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]\n        centerline = torch.from_numpy(centerline).float()\n        polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type])\n\n        res = tf_current_light[tf_current_light[\"lane_id\"] == 
str(lane_segment.id)]\n        if len(res) != 0:\n            polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res[\"state\"].item())\n\n        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)\n        center_vectors = centerline[1:] - centerline[:-1]\n        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)\n        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)\n        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)\n        center_type = _point_types.index('CENTERLINE')\n        point_type[lane_segment_idx] = torch.cat(\n            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)\n\n    for lane_segment in road_edges:\n        lane_segment = easydict.EasyDict(lane_segment)\n        lane_segment_idx = polygon_ids.index(lane_segment.id)\n        polyline_index = lane_segment.polyline_index\n        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]\n        centerline = torch.from_numpy(centerline).float()\n        polygon_type[lane_segment_idx] = _polygon_types.index(\"VEHICLE\")\n\n        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)\n        center_vectors = centerline[1:] - centerline[:-1]\n        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)\n        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)\n        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)\n        center_type = _point_types.index('EDGE')\n        point_type[lane_segment_idx] = torch.cat(\n            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)\n\n    for lane_segment in road_lines:\n        lane_segment = easydict.EasyDict(lane_segment)\n        
lane_segment_idx = polygon_ids.index(lane_segment.id)\n        polyline_index = lane_segment.polyline_index\n        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]\n        centerline = torch.from_numpy(centerline).float()\n\n        polygon_type[lane_segment_idx] = _polygon_types.index(\"VEHICLE\")\n\n        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)\n        center_vectors = centerline[1:] - centerline[:-1]\n        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)\n        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)\n        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)\n        center_type = _point_types.index(boundary_type_hash[lane_segment.type])\n        point_type[lane_segment_idx] = torch.cat(\n            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)\n\n    for crosswalk in crosswalks:\n        crosswalk = easydict.EasyDict(crosswalk)\n        lane_segment_idx = polygon_ids.index(crosswalk.id)\n        polyline_index = crosswalk.polyline_index\n        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]\n        centerline = torch.from_numpy(centerline).float()\n\n        polygon_type[lane_segment_idx] = _polygon_types.index(\"PEDESTRIAN\")\n\n        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)\n        center_vectors = centerline[1:] - centerline[:-1]\n        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)\n        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)\n        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)\n        center_type = _point_types.index(\"CROSSWALK\")\n        point_type[lane_segment_idx] = torch.cat(\n     
       [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)\n\n    num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long)\n    point_to_polygon_edge_index = torch.stack(\n        [torch.arange(num_points.sum(), dtype=torch.long),\n            torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0)\n    polygon_to_polygon_edge_index = []\n    polygon_to_polygon_type = []\n    for lane_segment in lane_segments:\n        lane_segment = easydict.EasyDict(lane_segment)\n        lane_segment_idx = polygon_ids.index(lane_segment.id)\n        pred_inds = []\n        for pred in lane_segment.entry_lanes:\n            pred_idx = safe_list_index(polygon_ids, pred)\n            if pred_idx is not None:\n                pred_inds.append(pred_idx)\n        if len(pred_inds) != 0:\n            polygon_to_polygon_edge_index.append(\n                torch.stack([torch.tensor(pred_inds, dtype=torch.long),\n                             torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0))\n            polygon_to_polygon_type.append(\n                torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8))\n        succ_inds = []\n        for succ in lane_segment.exit_lanes:\n            succ_idx = safe_list_index(polygon_ids, succ)\n            if succ_idx is not None:\n                succ_inds.append(succ_idx)\n        if len(succ_inds) != 0:\n            polygon_to_polygon_edge_index.append(\n                torch.stack([torch.tensor(succ_inds, dtype=torch.long),\n                             torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0))\n            polygon_to_polygon_type.append(\n                torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8))\n        if len(lane_segment.left_neighbors) != 0:\n            left_neighbor_ids = lane_segment.left_neighbors\n            
for left_neighbor_id in left_neighbor_ids:\n                left_idx = safe_list_index(polygon_ids, left_neighbor_id)\n                if left_idx is not None:\n                    polygon_to_polygon_edge_index.append(\n                        torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long))\n                    polygon_to_polygon_type.append(\n                        torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8))\n        if len(lane_segment.right_neighbors) != 0:\n            right_neighbor_ids = lane_segment.right_neighbors\n            for right_neighbor_id in right_neighbor_ids:\n                right_idx = safe_list_index(polygon_ids, right_neighbor_id)\n                if right_idx is not None:\n                    polygon_to_polygon_edge_index.append(\n                        torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long))\n                    polygon_to_polygon_type.append(\n                        torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8))\n    if len(polygon_to_polygon_edge_index) != 0:\n        polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1)\n        polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0)\n    else:\n        polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long)\n        polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8)\n\n    map_data = {\n        'map_polygon': {},\n        'map_point': {},\n        ('map_point', 'to', 'map_polygon'): {},\n        ('map_polygon', 'to', 'map_polygon'): {},\n    }\n    map_data['map_polygon']['num_nodes'] = num_polygons\n    map_data['map_polygon']['type'] = polygon_type\n    map_data['map_polygon']['light_type'] = polygon_light_type\n    if len(num_points) == 0:\n        map_data['map_point']['num_nodes'] = 0\n        map_data['map_point']['position'] = torch.tensor([], dtype=torch.float)\n        
map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float)\n        map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float)\n        if dim == 3:\n            map_data['map_point']['height'] = torch.tensor([], dtype=torch.float)\n        map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8)\n        map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8)\n    else:\n        map_data['map_point']['num_nodes'] = num_points.sum().item()\n        map_data['map_point']['position'] = torch.cat(point_position, dim=0)\n        map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0)\n        map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0)\n        if dim == 3:\n            map_data['map_point']['height'] = torch.cat(point_height, dim=0)\n        map_data['map_point']['type'] = torch.cat(point_type, dim=0)\n    map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index\n    map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index\n    map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type\n    # import matplotlib.pyplot as plt\n    # plt.axis('equal')\n    # plt.scatter(map_data['map_point']['position'][:, 0],\n    #             map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')\n    # plt.show(dpi=600)\n    return map_data\n\n\ndef process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, start_timestamp, end_timestamp):\n    agents_array = track_info[\"trajs\"].transpose(1, 0, 2)\n    object_id = np.array(track_info[\"object_id\"])\n    object_type = track_info[\"object_type\"]\n    id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))}\n    def type_hash(x):\n        tp = id_hash[x]\n        type_re_hash = {\n            \"TYPE_VEHICLE\": \"vehicle\",\n            \"TYPE_PEDESTRIAN\": \"pedestrian\",\n            
\"TYPE_CYCLIST\": \"cyclist\",\n            \"TYPE_OTHER\": \"background\",\n            \"TYPE_UNSET\": \"background\"\n        }\n        return type_re_hash[tp]\n\n    columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',\n               'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y',\n               'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps',\n               'focal_track_id', 'city']\n    new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11))\n    new_columns[:11, :, 0] = True\n    new_columns[11:, :, 0] = False\n    for index in range(new_columns.shape[0]):\n        new_columns[index, :, 4] = int(index)\n    new_columns[..., 1] = object_id\n    new_columns[..., 2] = object_id\n    new_columns[:, tracks_to_predict[\"track_index\"], 3] = 3\n    new_columns[..., 5] = 11\n    new_columns[..., 6] = int(start_timestamp)\n    new_columns[..., 7] = int(end_timestamp)\n    new_columns[..., 8] = int(91)\n    new_columns[..., 9] = object_id\n    new_columns[..., 10] = 10086\n    new_columns = new_columns\n    new_agents_array = np.concatenate([new_columns, agents_array], axis=-1)\n    new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1])\n    new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10]]\n    new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns)\n    new_agents_array[\"object_type\"] = new_agents_array[\"object_type\"].apply(func=type_hash)\n    new_agents_array[\"start_timestamp\"] = new_agents_array[\"start_timestamp\"].astype(int)\n    new_agents_array[\"end_timestamp\"] = new_agents_array[\"end_timestamp\"].astype(int)\n    new_agents_array[\"num_timestamps\"] = new_agents_array[\"num_timestamps\"].astype(int)\n    new_agents_array[\"scenario_id\"] = scenario_id\n    return 
new_agents_array\n\n\ndef process_dynamic_map(dynamic_map_infos):\n    lane_ids = dynamic_map_infos[\"lane_id\"]\n    tf_lights = []\n    for t in range(len(lane_ids)):\n        lane_id = lane_ids[t]\n        time = np.ones_like(lane_id) * t\n        state = dynamic_map_infos[\"state\"][t]\n        tf_light = np.concatenate([lane_id, time, state], axis=0)\n        tf_lights.append(tf_light)\n    tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0)\n    tf_lights = pd.DataFrame(data=tf_lights, columns=[\"lane_id\", \"time_step\", \"state\"])\n    tf_lights[\"time_step\"] = tf_lights[\"time_step\"].astype(\"str\")\n    tf_lights[\"lane_id\"] = tf_lights[\"lane_id\"].astype(\"str\")\n    tf_lights[\"state\"] = tf_lights[\"state\"].astype(\"str\")\n    tf_lights.loc[tf_lights[\"state\"].str.contains(\"STOP\"), [\"state\"] ] = 'LANE_STATE_STOP'\n    tf_lights.loc[tf_lights[\"state\"].str.contains(\"GO\"), [\"state\"] ] = 'LANE_STATE_GO'\n    tf_lights.loc[tf_lights[\"state\"].str.contains(\"CAUTION\"), [\"state\"] ] = 'LANE_STATE_CAUTION'\n    return tf_lights\n\n\npolyline_type = {\n    # for lane\n    'TYPE_UNDEFINED': -1,\n    'TYPE_FREEWAY': 1,\n    'TYPE_SURFACE_STREET': 2,\n    'TYPE_BIKE_LANE': 3,\n\n    # for roadline\n    'TYPE_UNKNOWN': -1,\n    'TYPE_BROKEN_SINGLE_WHITE': 6,\n    'TYPE_SOLID_SINGLE_WHITE': 7,\n    'TYPE_SOLID_DOUBLE_WHITE': 8,\n    'TYPE_BROKEN_SINGLE_YELLOW': 9,\n    'TYPE_BROKEN_DOUBLE_YELLOW': 10,\n    'TYPE_SOLID_SINGLE_YELLOW': 11,\n    'TYPE_SOLID_DOUBLE_YELLOW': 12,\n    'TYPE_PASSING_DOUBLE_YELLOW': 13,\n\n    # for roadedge\n    'TYPE_ROAD_EDGE_BOUNDARY': 15,\n    'TYPE_ROAD_EDGE_MEDIAN': 16,\n\n    # for stopsign\n    'TYPE_STOP_SIGN': 17,\n\n    # for crosswalk\n    'TYPE_CROSSWALK': 18,\n\n    # for speed bump\n    'TYPE_SPEED_BUMP': 19\n}\n\nobject_type = {\n    0: 'TYPE_UNSET',\n    1: 'TYPE_VEHICLE',\n    2: 'TYPE_PEDESTRIAN',\n    3: 'TYPE_CYCLIST',\n    4: 'TYPE_OTHER'\n}\n\n\nsignal_state = {\n    0: 
'LANE_STATE_UNKNOWN',\n\n    # // States for traffic signals with arrows.\n    1: 'LANE_STATE_ARROW_STOP',\n    2: 'LANE_STATE_ARROW_CAUTION',\n    3: 'LANE_STATE_ARROW_GO',\n\n    # // Standard round traffic signals.\n    4: 'LANE_STATE_STOP',\n    5: 'LANE_STATE_CAUTION',\n    6: 'LANE_STATE_GO',\n\n    # // Flashing light signals.\n    7: 'LANE_STATE_FLASHING_STOP',\n    8: 'LANE_STATE_FLASHING_CAUTION'\n}\n\nsignal_state_to_id = {}\nfor key, val in signal_state.items():\n    signal_state_to_id[val] = key\n\n\ndef decode_tracks_from_proto(tracks):\n    track_infos = {\n        'object_id': [],  # {0: unset, 1: vehicle, 2: pedestrian, 3: cyclist, 4: others}\n        'object_type': [],\n        'trajs': []\n    }\n    for cur_data in tracks:  # number of objects\n        cur_traj = [np.array([x.center_x, x.center_y, x.center_z, x.length, x.width, x.height, x.heading,\n                              x.velocity_x, x.velocity_y, x.valid], dtype=np.float32) for x in cur_data.states]\n        cur_traj = np.stack(cur_traj, axis=0)  # (num_timestamp, 10)\n\n        track_infos['object_id'].append(cur_data.id)\n        track_infos['object_type'].append(object_type[cur_data.object_type])\n        track_infos['trajs'].append(cur_traj)\n\n    track_infos['trajs'] = np.stack(track_infos['trajs'], axis=0)  # (num_objects, num_timestamp, 9)\n    return track_infos\n\n\nfrom collections import defaultdict\n\n\ndef decode_map_features_from_proto(map_features):\n    map_infos = {\n        'lane': [],\n        'road_line': [],\n        'road_edge': [],\n        'stop_sign': [],\n        'crosswalk': [],\n        'speed_bump': [],\n        'lane_dict': {},\n        'lane2other_dict': {}\n    }\n    polylines = []\n\n    point_cnt = 0\n    lane2other_dict = defaultdict(list)\n\n    for cur_data in map_features:\n        cur_info = {'id': cur_data.id}\n\n        if cur_data.lane.ByteSize() > 0:\n            cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph\n            
cur_info['type'] = cur_data.lane.type + 1  # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane\n            cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]\n\n            cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]\n\n            cur_info['interpolating'] = cur_data.lane.interpolating\n            cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)\n            cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)\n\n            cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]\n            cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]\n\n            cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]\n            cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]\n            cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]\n            cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]\n            cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]\n            cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]\n\n            lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])\n            lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])\n\n            global_type = cur_info['type']\n            cur_polyline = np.stack(\n                [np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],\n                axis=0)\n            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)\n            if cur_polyline.shape[0] <= 1:\n                continue\n            
map_infos['lane'].append(cur_info)\n            map_infos['lane_dict'][cur_data.id] = cur_info\n\n        elif cur_data.road_line.ByteSize() > 0:\n            cur_info['type'] = cur_data.road_line.type + 5\n\n            global_type = cur_info['type']\n            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in\n                                     cur_data.road_line.polyline], axis=0)\n            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)\n            if cur_polyline.shape[0] <= 1:\n                continue\n            map_infos['road_line'].append(cur_info)\n\n        elif cur_data.road_edge.ByteSize() > 0:\n            cur_info['type'] = cur_data.road_edge.type + 14\n\n            global_type = cur_info['type']\n            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in\n                                     cur_data.road_edge.polyline], axis=0)\n            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)\n            if cur_polyline.shape[0] <= 1:\n                continue\n            map_infos['road_edge'].append(cur_info)\n\n        elif cur_data.stop_sign.ByteSize() > 0:\n            cur_info['lane_ids'] = list(cur_data.stop_sign.lane)\n            for i in cur_info['lane_ids']:\n                lane2other_dict[i].append(cur_data.id)\n            point = cur_data.stop_sign.position\n            cur_info['position'] = np.array([point.x, point.y, point.z])\n\n            global_type = polyline_type['TYPE_STOP_SIGN']\n            cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)\n            if cur_polyline.shape[0] <= 1:\n                continue\n            map_infos['stop_sign'].append(cur_info)\n        elif cur_data.crosswalk.ByteSize() > 0:\n            global_type = polyline_type['TYPE_CROSSWALK']\n            cur_polyline = 
np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in\n                                     cur_data.crosswalk.polygon], axis=0)\n            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)\n            if cur_polyline.shape[0] <= 1:\n                continue\n            map_infos['crosswalk'].append(cur_info)\n\n        elif cur_data.speed_bump.ByteSize() > 0:\n            global_type = polyline_type['TYPE_SPEED_BUMP']\n            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in\n                                     cur_data.speed_bump.polygon], axis=0)\n            cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)\n            if cur_polyline.shape[0] <= 1:\n                continue\n            map_infos['speed_bump'].append(cur_info)\n\n        else:\n            # print(cur_data)\n            continue\n        polylines.append(cur_polyline)\n        cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))\n        point_cnt += len(cur_polyline)\n\n    # try:\n    polylines = np.concatenate(polylines, axis=0).astype(np.float32)\n    # except:\n    #     polylines = np.zeros((0, 8), dtype=np.float32)\n    #     print('Empty polylines: ')\n    map_infos['all_polylines'] = polylines\n    map_infos['lane2other_dict'] = lane2other_dict\n    return map_infos\n\n\ndef decode_dynamic_map_states_from_proto(dynamic_map_states):\n    dynamic_map_infos = {\n        'lane_id': [],\n        'state': [],\n        'stop_point': []\n    }\n    for cur_data in dynamic_map_states:  # (num_timestamp)\n        lane_id, state, stop_point = [], [], []\n        for cur_signal in cur_data.lane_states:  # (num_observed_signals)\n            lane_id.append(cur_signal.lane)\n            state.append(signal_state[cur_signal.state])\n            stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, 
cur_signal.stop_point.z])\n\n        dynamic_map_infos['lane_id'].append(np.array([lane_id]))\n        dynamic_map_infos['state'].append(np.array([state]))\n        dynamic_map_infos['stop_point'].append(np.array([stop_point]))\n\n    return dynamic_map_infos\n\n\ndef process_single_data(scenario):\n    info = {}\n    info['scenario_id'] = scenario.scenario_id\n    info['timestamps_seconds'] = list(scenario.timestamps_seconds)  # list of int of shape (91)\n    info['current_time_index'] = scenario.current_time_index  # int, 10\n    info['sdc_track_index'] = scenario.sdc_track_index  # int\n    info['objects_of_interest'] = list(scenario.objects_of_interest)  # list, could be empty list\n\n    info['tracks_to_predict'] = {\n        'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],\n        'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]\n    }  # for training: suggestion of objects to train on, for val/test: need to be predicted\n\n    track_infos = decode_tracks_from_proto(scenario.tracks)\n    info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in\n                                                info['tracks_to_predict']['track_index']]\n\n    # decode map related data\n    map_infos = decode_map_features_from_proto(scenario.map_features)\n    dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)\n\n    save_infos = {\n        'track_infos': track_infos,\n        'dynamic_map_infos': dynamic_map_infos,\n        'map_infos': map_infos\n    }\n    save_infos.update(info)\n    return save_infos\n\nimport tensorflow as tf\nfrom waymo_open_dataset.protos import scenario_pb2\n\n\ndef wm2argo(file, dir_name, output_dir):\n    file_path = os.path.join(dir_name, file)\n    dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)\n    for cnt, data in enumerate(dataset):\n        print(cnt)\n        scenario = 
scenario_pb2.Scenario()\n        scenario.ParseFromString(bytearray(data.numpy()))\n        save_infos = process_single_data(scenario) # pkl2mtr\n        map_info = save_infos[\"map_infos\"]\n        track_info = save_infos['track_infos']\n        scenario_id = save_infos['scenario_id']\n        tracks_to_predict = save_infos['tracks_to_predict']\n        sdc_track_index = save_infos['sdc_track_index']\n        av_id = track_info[\"object_id\"][sdc_track_index]\n        if len(tracks_to_predict[\"track_index\"]) < 1:\n            return\n        dynamic_map_infos = save_infos[\"dynamic_map_infos\"]\n        tf_lights = process_dynamic_map(dynamic_map_infos)\n        tf_current_light = tf_lights.loc[tf_lights[\"time_step\"] == \"11\"]\n        map_data = get_map_features(map_info, tf_current_light)\n        new_agents_array = process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, 0, 91) # mtr2argo\n        data = dict()\n        data['scenario_id'] = new_agents_array['scenario_id'].values[0]\n        data['city'] = new_agents_array['city'].values[0]\n        data['agent'] = get_agent_features(new_agents_array, av_id, num_historical_steps=11)\n        data.update(map_data)\n        with open(os.path.join(output_dir, scenario_id + '.pkl'), \"wb+\") as f:\n            pickle.dump(data, f)\n\n\ndef batch_process9s_transformer(dir_name, output_dir, num_workers=2):\n    from functools import partial\n    import multiprocessing\n    packages = os.listdir(dir_name)\n    func = partial(\n        wm2argo, output_dir=output_dir, dir_name=dir_name)\n    with multiprocessing.Pool(num_workers) as p:\n        list(tqdm(p.imap(func, packages), total=len(packages)))\n\n\nfrom argparse import ArgumentParser\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument('--input_dir', type=str, default='data/waymo/scenario/training')\n    parser.add_argument('--output_dir', type=str, default='data/waymo_processed/training')\n    args = 
parser.parse_args()\n    files = os.listdir(args.input_dir)\n    for file in tqdm(files):\n        wm2argo(file, args.input_dir, args.output_dir)\n    # batch_process9s_transformer(args.input_dir, args.output_dir, num_workers=\"ur_cpu_count\")\n"
  },
  {
    "path": "environment.yml",
    "content": "name: smart\nchannels:\n  - pytorch\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/\n  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n  - blas=1.0=mkl\n  - brotli-python=1.0.9=py39h6a678d5_8\n  - bzip2=1.0.8=h5eee18b_6\n  - ca-certificates=2024.9.24=h06a4308_0\n  - certifi=2024.8.30=py39h06a4308_0\n  - charset-normalizer=3.3.2=pyhd3eb1b0_0\n  - cudatoolkit=11.3.1=h2bc3f7f_2\n  - ffmpeg=4.3=hf484d3e_0\n  - freetype=2.12.1=h4a9f257_0\n  - gmp=6.2.1=h295c915_3\n  - gnutls=3.6.15=he1e5248_0\n  - idna=3.7=py39h06a4308_0\n  - intel-openmp=2023.1.0=hdb19cb5_46306\n  - jpeg=9e=h5eee18b_3\n  - lame=3.100=h7b6447c_0\n  - lcms2=2.12=h3be6417_0\n  - ld_impl_linux-64=2.40=h12ee557_0\n  - lerc=3.0=h295c915_0\n  - libdeflate=1.17=h5eee18b_1\n  - libffi=3.4.4=h6a678d5_1\n  - libgcc-ng=11.2.0=h1234567_1\n  - libgomp=11.2.0=h1234567_1\n  - libiconv=1.14=0\n  - libidn2=2.3.4=h5eee18b_0\n  - libpng=1.6.39=h5eee18b_0\n  - libstdcxx-ng=11.2.0=h1234567_1\n  - libtasn1=4.19.0=h5eee18b_0\n  - libtiff=4.5.1=h6a678d5_0\n  - libunistring=0.9.10=h27cfd23_0\n  - libwebp-base=1.3.2=h5eee18b_1\n  - lz4-c=1.9.4=h6a678d5_1\n  - mkl=2023.1.0=h213fc3f_46344\n  - mkl-service=2.4.0=py39h5eee18b_1\n  - mkl_fft=1.3.10=py39h5eee18b_0\n  - mkl_random=1.2.7=py39h1128e8f_0\n  - ncurses=6.4=h6a678d5_0\n  - nettle=3.7.3=hbbd107a_1\n  - openh264=2.1.1=h4ff587b_0\n  - openjpeg=2.5.2=he7f1fd0_0\n  - openssl=3.0.15=h5eee18b_0\n  - pillow=10.4.0=py39h5eee18b_0\n  - pip=24.2=py39h06a4308_0\n  - pysocks=1.7.1=py39h06a4308_0\n  - python=3.9.19=h955ad1f_1\n  - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0\n  - pytorch-mutex=1.0=cuda\n  - readline=8.2=h5eee18b_0\n  - requests=2.32.3=py39h06a4308_0\n  - 
setuptools=75.1.0=py39h06a4308_0\n  - sqlite=3.45.3=h5eee18b_0\n  - tbb=2021.8.0=hdb19cb5_0\n  - tk=8.6.14=h39e8969_0\n  - torchvision=0.13.1=py39_cu113\n  - typing_extensions=4.11.0=py39h06a4308_0\n  - urllib3=2.2.3=py39h06a4308_0\n  - wheel=0.44.0=py39h06a4308_0\n  - xz=5.4.6=h5eee18b_1\n  - zlib=1.2.13=h5eee18b_1\n  - zstd=1.5.6=hc292b87_0\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=42\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"smart\"\nversion = \"0.0.0\"\ndescription = \"Scalable Multi-agent Real-time Motion Generation via Next-token Prediction\"\nreadme = \"README.md\"\nauthors = [\n    {name = \"Xiaoxin Feng\"},\n    {name = \"Ziyan Gao\"},\n    {name = \"Yuheng Kan\"}\n]\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n    \"Operating System :: OS Independent\",\n]\nrequires-python = \">=3.9\"\ndependencies = [\n    \"easydict\",\n    \"numpy\",\n    \"pandas\",\n    \"pytorch-lightning\",\n    \"scipy\",\n    \"torch-cluster\",\n    \"torch-geometric\",\n    \"torch-scatter\",\n    \"torch\",\n    \"torchmetrics\",\n    \"tqdm\",\n]\n\n[project.urls]\n\"Homepage\" = \"https://smart-motion.github.io/smart/\"\n\"Repository\" = \"https://github.com/rainmaker22/SMART\"\n\"Paper\" = \"https://arxiv.org/abs/2405.15677\"\n\n[tool.setuptools]\npackages = [\"smart\"]\n"
  },
  {
    "path": "requirements.txt",
    "content": "aiohappyeyeballs==2.4.3\naiohttp==3.10.10\naiosignal==1.3.1\nasync-timeout==4.0.3\nattrs==24.2.0\ncontourpy==1.3.0\ncycler==0.12.1\neasydict==1.13\nfonttools==4.54.1\nfrozenlist==1.4.1\nfsspec==2024.10.0\nimportlib-resources==6.4.5\njinja2==3.1.4\nkiwisolver==1.4.7\nlightning-utilities==0.11.8\nmarkupsafe==3.0.2\nmatplotlib==3.9.2\nmultidict==6.1.0\nnumpy==1.26.4\npackaging==24.1\npandas==2.0.3\npropcache==0.2.0\npsutil==6.1.0\npyparsing==3.2.0\npython-dateutil==2.9.0.post0\npytorch-lightning==2.0.3\npytz==2024.2\npyyaml==6.0.1\nscipy==1.10.1\nshapely==2.0.6\nsix==1.16.0\ntorch-cluster==1.6.0+pt112cu113\ntorch-geometric==2.6.1\ntorch-scatter==2.1.0+pt112cu113\ntorch-sparse==0.6.16+pt112cu113\ntorch-spline-conv==1.2.1+pt112cu113\ntorchmetrics==1.5.0\ntqdm==4.66.5\ntzdata==2024.2\nyarl==1.16.0\nzipp==3.20.2\nwaymo-open-dataset-tf-2-12-0==1.6.4\n"
  },
  {
    "path": "scripts/install_pyg.sh",
    "content": "mkdir pyg_depend && cd pyg_depend\nwget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl\nwget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl\nwget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl\nwget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl\npython3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl\npython3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl\npython3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl\npython3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl\npython3 -m pip install torch_geometric\n"
  },
  {
    "path": "scripts/traj_clstering.py",
    "content": "from smart.utils.geometry import wrap_angle\nimport numpy as np\n\n\ndef average_distance_vectorized(point_set1, centroids):\n    dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :])**2, axis=-1))\n    return np.mean(dists, axis=2)\n\n\ndef assign_clusters(sub_X, centroids):\n    distances = average_distance_vectorized(sub_X, centroids)\n    return np.argmin(distances, axis=1)\n\n\ndef Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):\n    S = []\n    ret_traj_list = []\n    while len(S) < N:\n        num_all = X.shape[0]\n        # 随机选择第一个簇中心\n        choice_index = np.random.choice(num_all)\n        x0 = X[choice_index]\n        if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10:\n            continue\n        res_mask = np.sum((X - x0)**2, axis=(1, 2))/4 > (tol**2)\n        del_mask = np.sum((X - x0)**2, axis=(1, 2))/4 <= (tol**2)\n        if cal_mean_heading:\n            del_contour = X[del_mask]\n            diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :]\n            del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean()\n            x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length)\n            del_traj = a_pos[del_mask]\n            ret_traj = del_traj.mean(0)[None, ...]\n            if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0:\n                print(ret_traj)\n                print('1')\n        else:\n            x0 = x0[None, ...]\n            ret_traj = a_pos[choice_index][None, ...]\n        X = X[res_mask]\n        a_pos = a_pos[res_mask]\n        S.append(x0)\n        ret_traj_list.append(ret_traj)\n    centroids = np.concatenate(S, axis=0)\n    ret_traj = np.concatenate(ret_traj_list, axis=0)\n\n    # closest_dist_sq = np.sum((X - centroids[0])**2, axis=(1, 2))\n\n    # for k in range(1, K):\n    #     new_dist_sq = np.sum((X - centroids[k - 1])**2, axis=(1, 2))\n    #     closest_dist_sq = 
np.minimum(closest_dist_sq, new_dist_sq)\n    #     probabilities = closest_dist_sq / np.sum(closest_dist_sq)\n    #     centroids[k] = X[np.random.choice(N, p=probabilities)]\n\n    return centroids, ret_traj\n\n\ndef cal_polygon_contour(x, y, theta, width, length):\n\n    left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)\n    left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)\n    left_front = np.column_stack((left_front_x, left_front_y))\n\n    right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)\n    right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)\n    right_front = np.column_stack((right_front_x, right_front_y))\n\n    right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)\n    right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)\n    right_back = np.column_stack((right_back_x, right_back_y))\n\n    left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)\n    left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)\n    left_back = np.column_stack((left_back_x, left_back_y))\n\n    polygon_contour = np.concatenate((left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)\n\n    return polygon_contour\n\n\nif __name__ == '__main__':\n    shift = 5 # motion token time dimension\n    num_cluster = 6 # vocabulary size\n    cal_mean_heading = True\n    data = {\n        \"veh\": np.random.rand(1000, 6, 3),\n        \"cyc\": np.random.rand(1000, 6, 3),\n        \"ped\": np.random.rand(1000, 6, 3)\n    }\n    # Collect the trajectories of all traffic participants from the raw data [NumAgent, shift+1, [relative_x, relative_y, relative_theta]]\n    nms_res = {}\n    res = {'token': {}, 'traj': {}, 'token_all': {}}\n    for k, v in data.items():\n        # if k != 'veh':\n        #     continue\n        a_pos = v\n        
print(a_pos.shape)\n        # a_pos = a_pos[:, shift:1+shift, :]\n        cal_num = min(int(1e6), a_pos.shape[0])\n        a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)]\n        a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1])\n        print(a_pos.shape)\n        if shift <= 2:\n            if k == 'veh':\n                width = 1.0\n                length = 2.4\n            elif k == 'cyc':\n                width = 0.5\n                length = 1.5\n            else:\n                width = 0.5\n                length = 0.5\n        else:\n            if k == 'veh':\n                width = 2.0\n                length = 4.8\n            elif k == 'cyc':\n                width = 1.0\n                length = 2.0\n            else:\n                width = 1.0\n                length = 1.0\n        contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1], a_pos[:, shift, 2], width, length)\n\n        # plt.figure(figsize=(10, 10))\n        # for rect in contour:\n        #     rect_closed = np.vstack([rect, rect[0]])\n        #     plt.plot(rect_closed[:, 0], rect_closed[:, 1], linewidth=0.1)\n\n        # plt.title(\"Plot of 256 Rectangles\")\n        # plt.xlabel(\"x\")\n        # plt.ylabel(\"y\")\n        # plt.axis('equal')\n        # plt.savefig(f'src_{k}_new.jpg', dpi=300)\n\n        if k == 'veh':\n            tol = 0.05\n        elif k == 'cyc':\n            tol = 0.004\n        else:\n            tol = 0.004\n        centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length, a_pos[:, :shift+1])\n        # plt.figure(figsize=(10, 10))\n        contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(num_cluster*(shift+1)),\n                                      ret_traj[:, :, 1].reshape(num_cluster*(shift+1)),\n                                      ret_traj[:, :, 2].reshape(num_cluster*(shift+1)), width, length)\n\n        res['token_all'][k] = contour.reshape(num_cluster, (shift+1), 4, 2)\n        
res['token'][k] = centroids\n        res['traj'][k] = ret_traj\n"
  },
  {
    "path": "smart/__init__.py",
    "content": ""
  },
  {
    "path": "smart/datamodules/__init__.py",
    "content": "from smart.datamodules.scalable_datamodule import MultiDataModule\n"
  },
  {
    "path": "smart/datamodules/scalable_datamodule.py",
    "content": "from typing import Optional\n\nimport pytorch_lightning as pl\nfrom torch_geometric.loader import DataLoader\nfrom smart.datasets.scalable_dataset import MultiDataset\nfrom smart.transforms import WaymoTargetBuilder\n\n\nclass MultiDataModule(pl.LightningDataModule):\n    transforms = {\n        \"WaymoTargetBuilder\": WaymoTargetBuilder,\n    }\n\n    dataset = {\n        \"scalable\": MultiDataset,\n    }\n\n    def __init__(self,\n                 root: str,\n                 train_batch_size: int,\n                 val_batch_size: int,\n                 test_batch_size: int,\n                 shuffle: bool = False,\n                 num_workers: int = 0,\n                 pin_memory: bool = True,\n                 persistent_workers: bool = True,\n                 train_raw_dir: Optional[str] = None,\n                 val_raw_dir: Optional[str] = None,\n                 test_raw_dir: Optional[str] = None,\n                 train_processed_dir: Optional[str] = None,\n                 val_processed_dir: Optional[str] = None,\n                 test_processed_dir: Optional[str] = None,\n                 transform: Optional[str] = None,\n                 dataset: Optional[str] = None,\n                 num_historical_steps: int = 50,\n                 num_future_steps: int = 60,\n                 processor='ntp',\n                 use_intention=False,\n                 token_size=512,\n                 **kwargs) -> None:\n        super(MultiDataModule, self).__init__()\n        self.root = root\n        self.dataset_class = dataset\n        self.train_batch_size = train_batch_size\n        self.val_batch_size = val_batch_size\n        self.test_batch_size = test_batch_size\n        self.shuffle = shuffle\n        self.num_workers = num_workers\n        self.pin_memory = pin_memory\n        self.persistent_workers = persistent_workers and num_workers > 0\n        self.train_raw_dir = train_raw_dir\n        self.val_raw_dir = val_raw_dir\n        
self.test_raw_dir = test_raw_dir\n        self.train_processed_dir = train_processed_dir\n        self.val_processed_dir = val_processed_dir\n        self.test_processed_dir = test_processed_dir\n        self.processor = processor\n        self.use_intention = use_intention\n        self.token_size = token_size\n\n        train_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, \"train\")\n        val_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, \"val\")\n        test_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps)\n\n        self.train_transform = train_transform\n        self.val_transform = val_transform\n        self.test_transform = test_transform\n\n    def setup(self, stage: Optional[str] = None) -> None:\n        self.train_dataset = MultiDataModule.dataset[self.dataset_class](self.root, 'train', processed_dir=self.train_processed_dir,\n                                                                         raw_dir=self.train_raw_dir, processor=self.processor, transform=self.train_transform, token_size=self.token_size)\n        self.val_dataset = MultiDataModule.dataset[self.dataset_class](None, 'val', processed_dir=self.val_processed_dir,\n                                                                       raw_dir=self.val_raw_dir, processor=self.processor, transform=self.val_transform, token_size=self.token_size)\n        self.test_dataset = MultiDataModule.dataset[self.dataset_class](None, 'test', processed_dir=self.test_processed_dir,\n                                                                        raw_dir=self.test_raw_dir, processor=self.processor, transform=self.test_transform, token_size=self.token_size)\n\n    def train_dataloader(self):\n        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,\n                          num_workers=self.num_workers, 
pin_memory=self.pin_memory,\n                          persistent_workers=self.persistent_workers)\n\n    def val_dataloader(self):\n        return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False,\n                          num_workers=self.num_workers, pin_memory=self.pin_memory,\n                          persistent_workers=self.persistent_workers)\n\n    def test_dataloader(self):\n        return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False,\n                          num_workers=self.num_workers, pin_memory=self.pin_memory,\n                          persistent_workers=self.persistent_workers)\n"
  },
  {
    "path": "smart/datasets/__init__.py",
    "content": "from smart.datasets.scalable_dataset import MultiDataset\n"
  },
  {
    "path": "smart/datasets/preprocess.py",
    "content": "import torch\nimport numpy as np\nfrom scipy.interpolate import interp1d\nfrom scipy.spatial.distance import euclidean\nimport math\nimport pickle\nfrom smart.utils import wrap_angle\nimport os\n\ndef cal_polygon_contour(x, y, theta, width, length):\n    left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)\n    left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)\n    left_front = np.column_stack((left_front_x, left_front_y))\n\n    right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)\n    right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)\n    right_front = np.column_stack((right_front_x, right_front_y))\n\n    right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)\n    right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)\n    right_back = np.column_stack((right_back_x, right_back_y))\n\n    left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)\n    left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)\n    left_back = np.column_stack((left_back_x, left_back_y))\n\n    polygon_contour = np.concatenate(\n        (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)\n\n    return polygon_contour\n\n\ndef interplating_polyline(polylines, heading, distance=0.5, split_distace=5):\n    # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter\n    dist_along_path_list = [[0]]\n    polylines_list = [[polylines[0]]]\n    for i in range(1, polylines.shape[0]):\n        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])\n        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1])),\n                           abs(max(heading[i], heading[i - 1]) - min(heading[1], heading[i - 1]) + math.pi))\n        if heading_diff > math.pi / 4 and 
euclidean_dist > 3:\n            dist_along_path_list.append([0])\n            polylines_list.append([polylines[i]])\n        elif heading_diff > math.pi / 8 and euclidean_dist > 3:\n            dist_along_path_list.append([0])\n            polylines_list.append([polylines[i]])\n        elif heading_diff > 0.1 and euclidean_dist > 3:\n            dist_along_path_list.append([0])\n            polylines_list.append([polylines[i]])\n        elif euclidean_dist > 10:\n            dist_along_path_list.append([0])\n            polylines_list.append([polylines[i]])\n        else:\n            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)\n            polylines_list[-1].append(polylines[i])\n    # plt.plot(polylines[:, 0], polylines[:, 1])\n    # plt.savefig('tmp.jpg')\n    new_x_list = []\n    new_y_list = []\n    multi_polylines_list = []\n    for idx in range(len(dist_along_path_list)):\n        if len(dist_along_path_list[idx]) < 2:\n            continue\n        dist_along_path = np.array(dist_along_path_list[idx])\n        polylines_cur = np.array(polylines_list[idx])\n        # Create interpolation functions for x and y coordinates\n        fx = interp1d(dist_along_path, polylines_cur[:, 0])\n        fy = interp1d(dist_along_path, polylines_cur[:, 1])\n        # fyaw = interp1d(dist_along_path, heading)\n\n        # Create an array of distances at which to interpolate\n        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)\n        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])\n        # Use the interpolation functions to generate new x and y coordinates\n        new_x = fx(new_dist_along_path)\n        new_y = fy(new_dist_along_path)\n        # new_yaw = fyaw(new_dist_along_path)\n        new_x_list.append(new_x)\n        new_y_list.append(new_y)\n\n        # Combine the new x and y coordinates into a single array\n        new_polylines = np.vstack((new_x, new_y)).T\n        
polyline_size = int(split_distace / distance)\n        if new_polylines.shape[0] >= (polyline_size + 1):\n            padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size\n            final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1\n        else:\n            padding_size = new_polylines.shape[0]\n            final_index = 0\n        multi_polylines = None\n        new_polylines = torch.from_numpy(new_polylines)\n        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],\n                                  new_polylines[1:, 0] - new_polylines[:-1, 0])\n        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]\n        new_polylines = torch.cat([new_polylines, new_heading], -1)\n        if new_polylines.shape[0] >= (polyline_size + 1):\n            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)\n            multi_polylines = multi_polylines.transpose(1, 2)\n            multi_polylines = multi_polylines[:, ::5, :]\n        if padding_size >= 3:\n            last_polyline = new_polylines[final_index * polyline_size:]\n            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]\n            if multi_polylines is not None:\n                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)\n            else:\n                multi_polylines = last_polyline.unsqueeze(0)\n        if multi_polylines is None:\n            continue\n        multi_polylines_list.append(multi_polylines)\n    if len(multi_polylines_list) > 0:\n        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)\n    else:\n        multi_polylines_list = None\n    return multi_polylines_list\n\n\ndef average_distance_vectorized(point_set1, centroids):\n    dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1))\n    return 
np.mean(dists, axis=2)\n\n\ndef assign_clusters(sub_X, centroids):\n    distances = average_distance_vectorized(sub_X, centroids)\n    return np.argmin(distances, axis=1)\n\n\nclass TokenProcessor:\n\n    def __init__(self, token_size):\n        module_dir = os.path.dirname(os.path.dirname(__file__))\n        self.agent_token_path = os.path.join(module_dir, f'tokens/cluster_frame_5_{token_size}.pkl')\n        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')\n        self.noise = False\n        self.disturb = False\n        self.shift = 5\n        self.get_trajectory_token()\n        self.training = False\n        self.current_step = 10\n\n    def preprocess(self, data):\n        data = self.tokenize_agent(data)\n        data = self.tokenize_map(data)\n        del data['city']\n        if 'polygon_is_intersection' in data['map_polygon']:\n            del data['map_polygon']['polygon_is_intersection']\n        if 'route_type' in data['map_polygon']:\n            del data['map_polygon']['route_type']\n        return data\n\n    def get_trajectory_token(self):\n        agent_token_data = pickle.load(open(self.agent_token_path, 'rb'))\n        map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))\n        self.trajectory_token = agent_token_data['token']\n        self.trajectory_token_all = agent_token_data['token_all']\n        self.map_token = {'traj_src': map_token_traj['traj_src'], }\n        self.token_last = {}\n        for k, v in self.trajectory_token_all.items():\n            token_last = torch.from_numpy(v[:, -2:]).to(torch.float)\n            diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3]\n            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])\n            cos, sin = theta.cos(), theta.sin()\n            rot_mat = theta.new_zeros(token_last.shape[0], 2, 2)\n            rot_mat[:, 0, 0] = cos\n            rot_mat[:, 0, 1] = -sin\n            rot_mat[:, 1, 0] = sin\n            rot_mat[:, 1, 1] = cos\n 
           agent_token = torch.bmm(token_last[:, 1], rot_mat)\n            agent_token -= token_last[:, 0].mean(1)[:, None, :]\n            self.token_last[k] = agent_token.numpy()\n\n    def clean_heading(self, data):\n        heading = data['agent']['heading']\n        valid = data['agent']['valid_mask']\n        pi = torch.tensor(torch.pi)\n        n_vehicles, n_frames = heading.shape\n\n        heading_diff_raw = heading[:, :-1] - heading[:, 1:]\n        heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi\n        heading_diff[heading_diff > pi] -= 2 * pi\n        heading_diff[heading_diff < -pi] += 2 * pi\n\n        valid_pairs = valid[:, :-1] & valid[:, 1:]\n\n        for i in range(n_frames - 1):\n            change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1]\n\n            heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()]\n\n            if i < n_frames - 2:\n                heading_diff_raw = heading[:, i + 1] - heading[:, i + 2]\n                heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi\n                heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi\n                heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi\n\n    def tokenize_agent(self, data):\n        if data['agent'][\"velocity\"].shape[1] == 90:\n            print(data['scenario_id'], data['agent'][\"velocity\"].shape)\n        interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * (\n                data['agent']['position'][:, self.current_step, 0] != 0)\n        if data['agent'][\"velocity\"].shape[-1] == 2:\n            data['agent'][\"velocity\"] = torch.cat([data['agent'][\"velocity\"],\n                                                   torch.zeros(data['agent'][\"velocity\"].shape[0],\n                                                               data['agent'][\"velocity\"].shape[1], 1)], dim=-1)\n        vel = 
data['agent'][\"velocity\"][interplote_mask, self.current_step]\n        data['agent']['position'][interplote_mask, self.current_step - 1, :3] = data['agent']['position'][\n                                                                                interplote_mask, self.current_step,\n                                                                                :3] - vel * 0.1\n        data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True\n        data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][\n            interplote_mask, self.current_step]\n        data['agent'][\"velocity\"][interplote_mask, self.current_step - 1] = data['agent'][\"velocity\"][\n            interplote_mask, self.current_step]\n\n        data['agent']['type'] = data['agent']['type'].to(torch.uint8)\n\n        self.clean_heading(data)\n        matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * (\n                data['agent']['valid_mask'][:, self.current_step - 5] == False)\n\n        interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0)\n        data['agent']['valid_mask'][interplote_mask_first, 0] = True\n\n        agent_pos = data['agent']['position'][:, :, :2]\n        valid_mask = data['agent']['valid_mask']\n\n        valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)\n        token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]\n        agent_type = data['agent']['type']\n        agent_category = data['agent']['category']\n        agent_heading = data['agent']['heading']\n        vehicle_mask = agent_type == 0\n        cyclist_mask = agent_type == 2\n        ped_mask = agent_type == 1\n\n        veh_pos = agent_pos[vehicle_mask, :, :]\n        veh_valid_mask = valid_mask[vehicle_mask, :]\n        cyc_pos = agent_pos[cyclist_mask, :, :]\n        cyc_valid_mask = 
valid_mask[cyclist_mask, :]\n        ped_pos = agent_pos[ped_mask, :, :]\n        ped_valid_mask = valid_mask[ped_mask, :]\n\n        veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask],\n                                                              'veh', agent_category[vehicle_mask],\n                                                              matching_extra_mask[vehicle_mask])\n        ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped',\n                                                              agent_category[ped_mask], matching_extra_mask[ped_mask])\n        cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask],\n                                                              'cyc', agent_category[cyclist_mask],\n                                                              matching_extra_mask[cyclist_mask])\n\n        token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)\n        token_index[vehicle_mask] = veh_token_index\n        token_index[ped_mask] = ped_token_index\n        token_index[cyclist_mask] = cyc_token_index\n\n        token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],\n                                     veh_token_contour.shape[2], veh_token_contour.shape[3]))\n        token_contour[vehicle_mask] = veh_token_contour\n        token_contour[ped_mask] = ped_token_contour\n        token_contour[cyclist_mask] = cyc_token_contour\n\n        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float)\n        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float)\n        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float)\n\n        agent_token_traj = torch.zeros((agent_pos.shape[0], 
trajectory_token_veh.shape[0], 4, 2))\n        agent_token_traj[vehicle_mask] = trajectory_token_veh\n        agent_token_traj[ped_mask] = trajectory_token_ped\n        agent_token_traj[cyclist_mask] = trajectory_token_cyc\n\n        if not self.training:\n            token_valid_mask[matching_extra_mask, 1] = True\n\n        data['agent']['token_idx'] = token_index\n        data['agent']['token_contour'] = token_contour\n        token_pos = token_contour.mean(dim=2)\n        data['agent']['token_pos'] = token_pos\n        diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]\n        data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])\n        data['agent']['agent_valid_mask'] = token_valid_mask\n\n        vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2),\n                         ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1)\n        vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool),\n                                    (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1)\n        vel[~vel_valid_mask] = 0\n        vel[data['agent']['valid_mask'][:, self.current_step], 1] = data['agent']['velocity'][\n                                                                    data['agent']['valid_mask'][:, self.current_step],\n                                                                    self.current_step, :2]\n\n        data['agent']['token_velocity'] = vel\n\n        return data\n\n    def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask):\n        agent_token_src = self.trajectory_token[category]\n        token_last = self.token_last[category]\n        if self.shift <= 2:\n            if category == 'veh':\n                width = 1.0\n                length = 2.4\n            elif category == 'cyc':\n                width = 0.5\n                length = 1.5\n            else:\n       
         width = 0.5\n                length = 0.5\n        else:\n            if category == 'veh':\n                width = 2.0\n                length = 4.8\n            elif category == 'cyc':\n                width = 1.0\n                length = 2.0\n            else:\n                width = 1.0\n                length = 1.0\n\n        prev_heading = heading[:, 0]\n        prev_pos = pos[:, 0]\n        agent_num, num_step, feat_dim = pos.shape\n        token_num, token_contour_dim, feat_dim = agent_token_src.shape\n        agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)\n        token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0)\n        token_index_list = []\n        token_contour_list = []\n        prev_token_idx = None\n\n        for i in range(self.shift, pos.shape[1], self.shift):\n            theta = prev_heading\n            cur_heading = heading[:, i]\n            cur_pos = pos[:, i]\n            cos, sin = theta.cos(), theta.sin()\n            rot_mat = theta.new_zeros(agent_num, 2, 2)\n            rot_mat[:, 0, 0] = cos\n            rot_mat[:, 0, 1] = sin\n            rot_mat[:, 1, 0] = -sin\n            rot_mat[:, 1, 1] = cos\n            agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num,\n                                                                                                              token_num,\n                                                                                                              token_contour_dim,\n                                                                                                              feat_dim)\n            agent_token_world += prev_pos[:, None, None, :]\n\n            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)\n            agent_token_index = torch.from_numpy(np.argmin(\n 
               np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),\n                axis=-1))\n            if prev_token_idx is not None and self.noise:\n                same_idx = prev_token_idx == agent_token_index\n                same_idx[:] = True\n                topk_indices = np.argsort(\n                    np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),\n                            axis=2), axis=-1)[:, :5]\n                sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])\n                agent_token_index[same_idx] = \\\n                    torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]\n\n            token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index]\n\n            diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :]\n\n            prev_heading = heading[:, i].clone()\n            prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[\n                valid_mask[:, i - self.shift]]\n\n            prev_pos = pos[:, i].clone()\n            prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]]\n            prev_token_idx = agent_token_index\n            token_index_list.append(agent_token_index[:, None])\n            token_contour_list.append(token_contour_select[:, None, ...])\n\n        token_index = torch.cat(token_index_list, dim=1)\n        token_contour = torch.cat(token_contour_list, dim=1)\n\n        # extra matching\n        if not self.training:\n            theta = heading[extra_mask, self.current_step - 1]\n            prev_pos = pos[extra_mask, self.current_step - 1]\n            cur_pos = pos[extra_mask, self.current_step]\n            cur_heading = heading[extra_mask, self.current_step]\n            cos, sin = theta.cos(), 
theta.sin()\n            rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)\n            rot_mat[:, 0, 0] = cos\n            rot_mat[:, 0, 1] = sin\n            rot_mat[:, 1, 0] = -sin\n            rot_mat[:, 1, 1] = cos\n            agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(\n                extra_mask.sum(), token_num, token_contour_dim, feat_dim)\n            agent_token_world += prev_pos[:, None, None, :]\n\n            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)\n            agent_token_index = torch.from_numpy(np.argmin(\n                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),\n                axis=-1))\n            token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]\n\n            token_index[extra_mask, 1] = agent_token_index\n            token_contour[extra_mask, 1] = token_contour_select\n\n        return token_index, token_contour\n\n    def tokenize_map(self, data):\n        data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)\n        data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)\n        pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']\n        pt_type = data['map_point']['type'].to(torch.uint8)\n        pt_side = torch.zeros_like(pt_type)\n        pt_pos = data['map_point']['position'][:, :2]\n        data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])\n        pt_heading = data['map_point']['orientation']\n        split_polyline_type = []\n        split_polyline_pos = []\n        split_polyline_theta = []\n        split_polyline_side = []\n        pl_idx_list = []\n        split_polygon_type = []\n        data['map_point']['type'].unique()\n\n        for i in sorted(np.unique(pt2pl[1])):\n            index = pt2pl[0, pt2pl[1] == i]\n            polygon_type = 
data['map_polygon'][\"type\"][i]\n            cur_side = pt_side[index]\n            cur_type = pt_type[index]\n            cur_pos = pt_pos[index]\n            cur_heading = pt_heading[index]\n\n            for side_val in np.unique(cur_side):\n                for type_val in np.unique(cur_type):\n                    if type_val == 13:\n                        continue\n                    indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]\n                    if len(indices) <= 2:\n                        continue\n                    split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())\n                    if split_polyline is None:\n                        continue\n                    new_cur_type = cur_type[indices][0]\n                    new_cur_side = cur_side[indices][0]\n                    map_polygon_type = polygon_type.repeat(split_polyline.shape[0])\n                    new_cur_type = new_cur_type.repeat(split_polyline.shape[0])\n                    new_cur_side = new_cur_side.repeat(split_polyline.shape[0])\n                    cur_pl_idx = torch.Tensor([i])\n                    new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])\n                    split_polyline_pos.append(split_polyline[..., :2])\n                    split_polyline_theta.append(split_polyline[..., 2])\n                    split_polyline_type.append(new_cur_type)\n                    split_polyline_side.append(new_cur_side)\n                    pl_idx_list.append(new_cur_pl_idx)\n                    split_polygon_type.append(map_polygon_type)\n\n        split_polyline_pos = torch.cat(split_polyline_pos, dim=0)\n        split_polyline_theta = torch.cat(split_polyline_theta, dim=0)\n        split_polyline_type = torch.cat(split_polyline_type, dim=0)\n        split_polyline_side = torch.cat(split_polyline_side, dim=0)\n        split_polygon_type = torch.cat(split_polygon_type, dim=0)\n        pl_idx_list = 
torch.cat(pl_idx_list, dim=0)\n        vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :]\n        data['map_save'] = {}\n        data['pt_token'] = {}\n        data['map_save']['traj_pos'] = split_polyline_pos\n        data['map_save']['traj_theta'] = split_polyline_theta[:, 0]  # torch.arctan2(vec[:, 1], vec[:, 0])\n        data['map_save']['pl_idx_list'] = pl_idx_list\n        data['pt_token']['type'] = split_polyline_type\n        data['pt_token']['side'] = split_polyline_side\n        data['pt_token']['pl_type'] = split_polygon_type\n        data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]\n        return data"
  },
  {
    "path": "smart/datasets/scalable_dataset.py",
    "content": "import os\nimport pickle\nfrom typing import Callable, List, Optional, Tuple, Union\nimport pandas as pd\nfrom torch_geometric.data import Dataset\nfrom smart.utils.log import Logging\nimport numpy as np\nfrom .preprocess import TokenProcessor\n\n\ndef distance(point1, point2):\n    return np.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2)\n\n\nclass MultiDataset(Dataset):\n    def __init__(self,\n                 root: str,\n                 split: str,\n                 raw_dir: List[str] = None,\n                 processed_dir: List[str] = None,\n                 transform: Optional[Callable] = None,\n                 dim: int = 3,\n                 num_historical_steps: int = 50,\n                 num_future_steps: int = 60,\n                 predict_unseen_agents: bool = False,\n                 vector_repr: bool = True,\n                 cluster: bool = False,\n                 processor=None,\n                 use_intention=False,\n                 token_size=512) -> None:\n        self.logger = Logging().log(level='DEBUG')\n        self.root = root\n        self.well_done = [0]\n        if split not in ('train', 'val', 'test'):\n            raise ValueError(f'{split} is not a valid split')\n        self.split = split\n        self.training = split == 'train'\n        self.logger.debug(\"Starting loading dataset\")\n        self._raw_file_names = []\n        self._raw_paths = []\n        self._raw_file_dataset = []\n        if raw_dir is not None:\n            self._raw_dir = raw_dir\n            for raw_dir in self._raw_dir:\n                raw_dir = os.path.expanduser(os.path.normpath(raw_dir))\n                dataset = \"waymo\"\n                file_list = os.listdir(raw_dir)\n                self._raw_file_names.extend(file_list)\n                self._raw_paths.extend([os.path.join(raw_dir, f) for f in file_list])\n                self._raw_file_dataset.extend([dataset for _ in range(len(file_list))])\n        if 
self.root is not None:\n            split_datainfo = os.path.join(root, \"split_datainfo.pkl\")\n            with open(split_datainfo, 'rb+') as f:\n                split_datainfo = pickle.load(f)\n            if split == \"test\":\n                split = \"val\"\n            self._processed_file_names = split_datainfo[split]\n        self.dim = dim\n        self.num_historical_steps = num_historical_steps\n        self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names)\n        self.logger.debug(\"The number of {} dataset is \".format(split) + str(self._num_samples))\n        self.token_processor = TokenProcessor(2048)\n        super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)\n\n    @property\n    def raw_dir(self) -> str:\n        return self._raw_dir\n\n    @property\n    def raw_paths(self) -> List[str]:\n        return self._raw_paths\n\n    @property\n    def raw_file_names(self) -> Union[str, List[str], Tuple]:\n        return self._raw_file_names\n\n    @property\n    def processed_file_names(self) -> Union[str, List[str], Tuple]:\n        return self._processed_file_names\n\n    def len(self) -> int:\n        return self._num_samples\n\n    def generate_ref_token(self):\n        pass\n\n    def get(self, idx: int):\n        with open(self.raw_paths[idx], 'rb') as handle:\n            data = pickle.load(handle)\n        data = self.token_processor.preprocess(data)\n        return data\n"
  },
  {
    "path": "smart/layers/__init__.py",
    "content": "\nfrom smart.layers.attention_layer import AttentionLayer\nfrom smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding\nfrom smart.layers.mlp_layer import MLPLayer\n"
  },
  {
    "path": "smart/layers/attention_layer.py",
    "content": "\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.utils import softmax\n\nfrom smart.utils import weight_init\n\n\nclass AttentionLayer(MessagePassing):\n\n    def __init__(self,\n                 hidden_dim: int,\n                 num_heads: int,\n                 head_dim: int,\n                 dropout: float,\n                 bipartite: bool,\n                 has_pos_emb: bool,\n                 **kwargs) -> None:\n        super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.has_pos_emb = has_pos_emb\n        self.scale = head_dim ** -0.5\n\n        self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)\n        self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)\n        self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)\n        if has_pos_emb:\n            self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)\n            self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)\n        self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)\n        self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)\n        self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)\n        self.attn_drop = nn.Dropout(dropout)\n        self.ff_mlp = nn.Sequential(\n            nn.Linear(hidden_dim, hidden_dim * 4),\n            nn.ReLU(inplace=True),\n            nn.Dropout(dropout),\n            nn.Linear(hidden_dim * 4, hidden_dim),\n        )\n        if bipartite:\n            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)\n            self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)\n        else:\n            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)\n            self.attn_prenorm_x_dst = self.attn_prenorm_x_src\n        if has_pos_emb:\n            
self.attn_prenorm_r = nn.LayerNorm(hidden_dim)\n        self.attn_postnorm = nn.LayerNorm(hidden_dim)\n        self.ff_prenorm = nn.LayerNorm(hidden_dim)\n        self.ff_postnorm = nn.LayerNorm(hidden_dim)\n        self.apply(weight_init)\n\n    def forward(self,\n                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],\n                r: Optional[torch.Tensor],\n                edge_index: torch.Tensor) -> torch.Tensor:\n        if isinstance(x, torch.Tensor):\n            x_src = x_dst = self.attn_prenorm_x_src(x)\n        else:\n            x_src, x_dst = x\n            x_src = self.attn_prenorm_x_src(x_src)\n            x_dst = self.attn_prenorm_x_dst(x_dst)\n            x = x[1]\n        if self.has_pos_emb and r is not None:\n            r = self.attn_prenorm_r(r)\n        x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))\n        x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))\n        return x\n\n    def message(self,\n                q_i: torch.Tensor,\n                k_j: torch.Tensor,\n                v_j: torch.Tensor,\n                r: Optional[torch.Tensor],\n                index: torch.Tensor,\n                ptr: Optional[torch.Tensor]) -> torch.Tensor:\n        if self.has_pos_emb and r is not None:\n            k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)\n            v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)\n        sim = (q_i * k_j).sum(dim=-1) * self.scale\n        attn = softmax(sim, index, ptr)\n        self.attention_weight = attn.sum(-1).detach()\n        attn = self.attn_drop(attn)\n        return v_j * attn.unsqueeze(-1)\n\n    def update(self,\n               inputs: torch.Tensor,\n               x_dst: torch.Tensor) -> torch.Tensor:\n        inputs = inputs.view(-1, self.num_heads * self.head_dim)\n        g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))\n        return inputs + g * (self.to_s(x_dst) - 
inputs)\n\n    def _attn_block(self,\n                    x_src: torch.Tensor,\n                    x_dst: torch.Tensor,\n                    r: Optional[torch.Tensor],\n                    edge_index: torch.Tensor) -> torch.Tensor:\n        q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)\n        k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)\n        v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)\n        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)\n        return self.to_out(agg)\n\n    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:\n        return self.ff_mlp(x)\n"
  },
  {
    "path": "smart/layers/fourier_embedding.py",
    "content": "import math\nfrom typing import List, Optional\nimport torch\nimport torch.nn as nn\n\nfrom smart.utils import weight_init\n\n\nclass FourierEmbedding(nn.Module):\n\n    def __init__(self,\n                 input_dim: int,\n                 hidden_dim: int,\n                 num_freq_bands: int) -> None:\n        super(FourierEmbedding, self).__init__()\n        self.input_dim = input_dim\n        self.hidden_dim = hidden_dim\n\n        self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None\n        self.mlps = nn.ModuleList(\n            [nn.Sequential(\n                nn.Linear(num_freq_bands * 2 + 1, hidden_dim),\n                nn.LayerNorm(hidden_dim),\n                nn.ReLU(inplace=True),\n                nn.Linear(hidden_dim, hidden_dim),\n            )\n                for _ in range(input_dim)])\n        self.to_out = nn.Sequential(\n            nn.LayerNorm(hidden_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(hidden_dim, hidden_dim),\n        )\n        self.apply(weight_init)\n\n    def forward(self,\n                continuous_inputs: Optional[torch.Tensor] = None,\n                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:\n        if continuous_inputs is None:\n            if categorical_embs is not None:\n                x = torch.stack(categorical_embs).sum(dim=0)\n            else:\n                raise ValueError('Both continuous_inputs and categorical_embs are None')\n        else:\n            x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi\n            # Warning: if your data are noisy, don't use learnable sinusoidal embedding\n            x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)\n            continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim\n            for i in range(self.input_dim):\n                continuous_embs[i] = self.mlps[i](x[:, i])\n            x = 
torch.stack(continuous_embs).sum(dim=0)\n            if categorical_embs is not None:\n                x = x + torch.stack(categorical_embs).sum(dim=0)\n        return self.to_out(x)\n\n\nclass MLPEmbedding(nn.Module):\n    def __init__(self,\n                 input_dim: int,\n                 hidden_dim: int) -> None:\n        super(MLPEmbedding, self).__init__()\n        self.input_dim = input_dim\n        self.hidden_dim = hidden_dim\n        self.mlp = nn.Sequential(\n            nn.Linear(input_dim, 128),\n            nn.LayerNorm(128),\n            nn.ReLU(inplace=True),\n            nn.Linear(128, hidden_dim),\n            nn.LayerNorm(hidden_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(hidden_dim, hidden_dim))\n        self.apply(weight_init)\n\n    def forward(self,\n                continuous_inputs: Optional[torch.Tensor] = None,\n                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:\n        if continuous_inputs is None:\n            if categorical_embs is not None:\n                x = torch.stack(categorical_embs).sum(dim=0)\n            else:\n                raise ValueError('Both continuous_inputs and categorical_embs are None')\n        else:\n            x = self.mlp(continuous_inputs)\n            if categorical_embs is not None:\n                x = x + torch.stack(categorical_embs).sum(dim=0)\n        return x\n"
  },
  {
    "path": "smart/layers/mlp_layer.py",
    "content": "\nimport torch\nimport torch.nn as nn\n\nfrom smart.utils import weight_init\n\n\nclass MLPLayer(nn.Module):\n\n    def __init__(self,\n                 input_dim: int,\n                 hidden_dim: int,\n                 output_dim: int) -> None:\n        super(MLPLayer, self).__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(input_dim, hidden_dim),\n            nn.LayerNorm(hidden_dim),\n            nn.ReLU(inplace=True),\n            nn.Linear(hidden_dim, output_dim),\n        )\n        self.apply(weight_init)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.mlp(x)\n"
  },
  {
    "path": "smart/metrics/__init__.py",
    "content": "\nfrom smart.metrics.average_meter import AverageMeter\nfrom smart.metrics.min_ade import minADE\nfrom smart.metrics.min_fde import minFDE\nfrom smart.metrics.next_token_cls import TokenCls\n"
  },
  {
    "path": "smart/metrics/average_meter.py",
    "content": "\nimport torch\nfrom torchmetrics import Metric\n\n\nclass AverageMeter(Metric):\n\n    def __init__(self, **kwargs) -> None:\n        super(AverageMeter, self).__init__(**kwargs)\n        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')\n\n    def update(self, val: torch.Tensor) -> None:\n        self.sum += val.sum()\n        self.count += val.numel()\n\n    def compute(self) -> torch.Tensor:\n        return self.sum / self.count\n"
  },
  {
    "path": "smart/metrics/min_ade.py",
    "content": "\nfrom typing import Optional\n\nimport torch\nfrom torchmetrics import Metric\n\nfrom smart.metrics.utils import topk\nfrom smart.metrics.utils import valid_filter\n\n\nclass minMultiADE(Metric):\n\n    def __init__(self,\n                 max_guesses: int = 6,\n                 **kwargs) -> None:\n        super(minMultiADE, self).__init__(**kwargs)\n        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')\n        self.max_guesses = max_guesses\n\n    def update(self,\n               pred: torch.Tensor,\n               target: torch.Tensor,\n               prob: Optional[torch.Tensor] = None,\n               valid_mask: Optional[torch.Tensor] = None,\n               keep_invalid_final_step: bool = True,\n               min_criterion: str = 'FDE') -> None:\n        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)\n        pred_topk, _ = topk(self.max_guesses, pred, prob)\n        if min_criterion == 'FDE':\n            inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)\n            inds_best = torch.norm(\n                pred_topk[torch.arange(pred.size(0)), :, inds_last] -\n                target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)\n            self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *\n                          valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()\n        elif min_criterion == 'ADE':\n            self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *\n                          valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()\n        else:\n            raise ValueError('{} is not a valid criterion'.format(min_criterion))\n        self.count += pred.size(0)\n\n 
   def compute(self) -> torch.Tensor:\n        return self.sum / self.count\n\n\nclass minADE(Metric):\n\n    def __init__(self,\n                 max_guesses: int = 6,\n                 **kwargs) -> None:\n        super(minADE, self).__init__(**kwargs)\n        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')\n        self.max_guesses = max_guesses\n        self.eval_timestep = 70\n\n    def update(self,\n               pred: torch.Tensor,\n               target: torch.Tensor,\n               prob: Optional[torch.Tensor] = None,\n               valid_mask: Optional[torch.Tensor] = None,\n               keep_invalid_final_step: bool = True,\n               min_criterion: str = 'ADE') -> None:\n        # pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)\n        # pred_topk, _ = topk(self.max_guesses, pred, prob)\n        # if min_criterion == 'FDE':\n        #     inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)\n        #     inds_best = torch.norm(\n        #         pred[torch.arange(pred.size(0)), :, inds_last] -\n        #         target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)\n        #     self.sum += ((torch.norm(pred[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *\n        #                   valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()\n        # elif min_criterion == 'ADE':\n        #     self.sum += ((torch.norm(pred - target.unsqueeze(1), p=2, dim=-1) *\n        #                   valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()\n        # else:\n        #     raise ValueError('{} is not a valid criterion'.format(min_criterion))\n        eval_timestep = min(self.eval_timestep, pred.shape[1])\n        self.sum += ((torch.norm(pred[:, 
:eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()\n        self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()\n\n    def compute(self) -> torch.Tensor:\n        return self.sum / self.count\n"
  },
  {
    "path": "smart/metrics/min_fde.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torchmetrics import Metric\n\nfrom smart.metrics.utils import topk\nfrom smart.metrics.utils import valid_filter\n\n\nclass minMultiFDE(Metric):\n\n    def __init__(self,\n                 max_guesses: int = 6,\n                 **kwargs) -> None:\n        super(minMultiFDE, self).__init__(**kwargs)\n        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')\n        self.max_guesses = max_guesses\n\n    def update(self,\n               pred: torch.Tensor,\n               target: torch.Tensor,\n               prob: Optional[torch.Tensor] = None,\n               valid_mask: Optional[torch.Tensor] = None,\n               keep_invalid_final_step: bool = True) -> None:\n        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)\n        pred_topk, _ = topk(self.max_guesses, pred, prob)\n        inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)\n        self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -\n                               target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),\n                               p=2, dim=-1).min(dim=-1)[0].sum()\n        self.count += pred.size(0)\n\n    def compute(self) -> torch.Tensor:\n        return self.sum / self.count\n\n\nclass minFDE(Metric):\n\n    def __init__(self,\n                 max_guesses: int = 6,\n                 **kwargs) -> None:\n        super(minFDE, self).__init__(**kwargs)\n        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')\n        self.max_guesses = max_guesses\n        self.eval_timestep = 70\n\n    def update(self,\n               pred: torch.Tensor,\n               target: 
torch.Tensor,\n               prob: Optional[torch.Tensor] = None,\n               valid_mask: Optional[torch.Tensor] = None,\n               keep_invalid_final_step: bool = True) -> None:\n        eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1\n        self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *\n                      valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()\n        self.count += valid_mask[:, eval_timestep-1].sum()\n\n    def compute(self) -> torch.Tensor:\n        return self.sum / self.count\n"
  },
  {
    "path": "smart/metrics/next_token_cls.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torchmetrics import Metric\n\nfrom smart.metrics.utils import topk\nfrom smart.metrics.utils import valid_filter\n\n\nclass TokenCls(Metric):\n\n    def __init__(self,\n                 max_guesses: int = 6,\n                 **kwargs) -> None:\n        super(TokenCls, self).__init__(**kwargs)\n        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')\n        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')\n        self.max_guesses = max_guesses\n\n    def update(self,\n               pred: torch.Tensor,\n               target: torch.Tensor,\n               valid_mask: Optional[torch.Tensor] = None) -> None:\n        target = target[..., None]\n        acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask\n        self.sum += acc.sum()\n        self.count += valid_mask.sum()\n\n    def compute(self) -> torch.Tensor:\n        return self.sum / self.count\n"
  },
  {
    "path": "smart/metrics/utils.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch_scatter import gather_csr\nfrom torch_scatter import segment_csr\n\n\ndef topk(\n        max_guesses: int,\n        pred: torch.Tensor,\n        prob: Optional[torch.Tensor] = None,\n        ptr: Optional[torch.Tensor] = None,\n        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:\n    max_guesses = min(max_guesses, pred.size(1))\n    if max_guesses == pred.size(1):\n        if prob is not None:\n            prob = prob / prob.sum(dim=-1, keepdim=True)\n        else:\n            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses\n        return pred, prob\n    else:\n        if prob is not None:\n            if joint:\n                if ptr is None:\n                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),\n                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]\n                    inds_topk = inds_topk.repeat(pred.size(0), 1)\n                else:\n                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,\n                                                       reduce='mean'),\n                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]\n                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)\n            else:\n                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]\n            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]\n            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]\n            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)\n        else:\n            pred_topk = pred[:, :max_guesses]\n            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses\n        return pred_topk, 
prob_topk\n\n\ndef topkind(\n        max_guesses: int,\n        pred: torch.Tensor,\n        prob: Optional[torch.Tensor] = None,\n        ptr: Optional[torch.Tensor] = None,\n        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n    max_guesses = min(max_guesses, pred.size(1))\n    if max_guesses == pred.size(1):\n        if prob is not None:\n            prob = prob / prob.sum(dim=-1, keepdim=True)\n        else:\n            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses\n        return pred, prob, None\n    else:\n        if prob is not None:\n            if joint:\n                if ptr is None:\n                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),\n                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]\n                    inds_topk = inds_topk.repeat(pred.size(0), 1)\n                else:\n                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,\n                                                       reduce='mean'),\n                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]\n                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)\n            else:\n                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]\n            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]\n            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]\n            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)\n        else:\n            pred_topk = pred[:, :max_guesses]\n            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses\n        return pred_topk, prob_topk, inds_topk\n\n\ndef valid_filter(\n        pred: torch.Tensor,\n        target: torch.Tensor,\n        prob: 
Optional[torch.Tensor] = None,\n        valid_mask: Optional[torch.Tensor] = None,\n        ptr: Optional[torch.Tensor] = None,\n        keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],\n                                                       torch.Tensor, torch.Tensor]:\n    if valid_mask is None:\n        valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)\n    if keep_invalid_final_step:\n        filter_mask = valid_mask.any(dim=-1)\n    else:\n        filter_mask = valid_mask[:, -1]\n    pred = pred[filter_mask]\n    target = target[filter_mask]\n    if prob is not None:\n        prob = prob[filter_mask]\n    valid_mask = valid_mask[filter_mask]\n    if ptr is not None:\n        num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')\n        ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))\n        torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])\n    else:\n        ptr = target.new_tensor([0, target.size(0)])\n    return pred, target, prob, valid_mask, ptr\n\n\ndef new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):\n    \"\"\"\n\n    Args:\n        pred_trajs (batch_size, num_modes, num_timestamps, 7)\n        pred_scores (batch_size, num_modes):\n        dist_thresh (float):\n        num_ret_modes (int, optional): Defaults to 6.\n\n    Returns:\n        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)\n        ret_scores (batch_size, num_ret_modes)\n        ret_idxs (batch_size, num_ret_modes)\n    \"\"\"\n    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape\n    pred_goals = pred_trajs[:, :, -1, :]\n    dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)\n    nearby_neighbor = dist < dist_thresh\n    pred_scores = nearby_neighbor.sum(dim=-1) / num_modes\n\n    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)\n    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, 
None].repeat(1, num_modes)\n    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]\n    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)\n    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)\n\n    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)\n    point_cover_mask = (dist < dist_thresh)\n\n    point_val = sorted_pred_scores.clone()  # (batch_size, N)\n    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)\n\n    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()\n    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)\n    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)\n    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)\n\n    for k in range(num_ret_modes):\n        cur_idx = point_val.argmax(dim=-1)  # (batch_size)\n        ret_idxs[:, k] = cur_idx\n\n        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)\n        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)\n        point_val_selected[bs_idxs, cur_idx] = -1\n        point_val += point_val_selected\n\n        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]\n        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]\n\n    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)\n\n    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]\n    return ret_trajs, ret_scores, ret_idxs\n\n\ndef batch_nms(pred_trajs, pred_scores,\n              dist_thresh, num_ret_modes=6,\n              mode='static', speed=None):\n    \"\"\"\n\n    Args:\n        pred_trajs (batch_size, num_modes, num_timestamps, 7)\n        pred_scores (batch_size, num_modes):\n        dist_thresh (float):\n        num_ret_modes (int, optional): Defaults to 6.\n\n    Returns:\n        ret_trajs (batch_size, num_ret_modes, 
num_timestamps, 5)\n        ret_scores (batch_size, num_ret_modes)\n        ret_idxs (batch_size, num_ret_modes)\n    \"\"\"\n    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape\n\n    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)\n    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)\n    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]\n    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)\n    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)\n\n    if mode == \"speed\":\n        scale = torch.ones(batch_size).to(sorted_pred_goals.device)\n        lon_dist_thresh = 4 * scale\n        lat_dist_thresh = 0.5 * scale\n        lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)\n        lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)\n        point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])\n    else:\n        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)\n        point_cover_mask = (dist < dist_thresh)\n\n    point_val = sorted_pred_scores.clone()  # (batch_size, N)\n    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)\n\n    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()\n    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)\n    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)\n    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)\n\n    for k in range(num_ret_modes):\n        cur_idx = point_val.argmax(dim=-1)  # (batch_size)\n        ret_idxs[:, k] = cur_idx\n\n        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)\n        point_val = point_val * 
(~new_cover_mask).float()  # (batch_size, N)\n        point_val_selected[bs_idxs, cur_idx] = -1\n        point_val += point_val_selected\n\n        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]\n        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]\n\n    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)\n\n    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]\n    return ret_trajs, ret_scores, ret_idxs\n\n\ndef batch_nms_token(pred_trajs, pred_scores,\n                    dist_thresh, num_ret_modes=6,\n                    mode='static', speed=None):\n    \"\"\"\n    Args:\n        pred_trajs (batch_size, num_modes, num_timestamps, 7)\n        pred_scores (batch_size, num_modes):\n        dist_thresh (float):\n        num_ret_modes (int, optional): Defaults to 6.\n\n    Returns:\n        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)\n        ret_scores (batch_size, num_ret_modes)\n        ret_idxs (batch_size, num_ret_modes)\n    \"\"\"\n    batch_size, num_modes, num_feat_dim = pred_trajs.shape\n\n    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)\n    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)\n    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]\n    sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)\n\n    if mode == \"nearby\":\n        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)\n        values, indices = torch.topk(dist, 5, dim=-1, largest=False)\n        thresh_hold = values[..., -1]\n        point_cover_mask = dist < thresh_hold[..., None]\n    else:\n        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)\n        point_cover_mask = (dist < dist_thresh)\n\n    point_val = sorted_pred_scores.clone()  # (batch_size, N)\n    point_val_selected = torch.zeros_like(point_val)  # 
(batch_size, N)\n\n    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()\n    ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)\n    ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)\n    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)\n\n    for k in range(num_ret_modes):\n        cur_idx = point_val.argmax(dim=-1)  # (batch_size)\n        ret_idxs[:, k] = cur_idx\n\n        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)\n        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)\n        point_val_selected[bs_idxs, cur_idx] = -1\n        point_val += point_val_selected\n\n        ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]\n        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]\n\n    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)\n\n    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]\n    return ret_goals, ret_scores, ret_idxs\n"
  },
  {
    "path": "smart/model/__init__.py",
    "content": "from smart.model.smart import SMART\n"
  },
  {
    "path": "smart/model/smart.py",
    "content": "import contextlib\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn as nn\nfrom torch_geometric.data import Batch\nfrom torch_geometric.data import HeteroData\nfrom smart.metrics import minADE\nfrom smart.metrics import minFDE\nfrom smart.metrics import TokenCls\nfrom smart.modules import SMARTDecoder\nfrom torch.optim.lr_scheduler import LambdaLR\nimport math\nimport numpy as np\nimport pickle\nfrom collections import defaultdict\nimport os\nfrom waymo_open_dataset.protos import sim_agents_submission_pb2\n\n\ndef cal_polygon_contour(x, y, theta, width, length):\n    left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)\n    left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)\n    left_front = (left_front_x, left_front_y)\n\n    right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)\n    right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)\n    right_front = (right_front_x, right_front_y)\n\n    right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)\n    right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)\n    right_back = (right_back_x, right_back_y)\n\n    left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)\n    left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)\n    left_back = (left_back_x, left_back_y)\n    polygon_contour = [left_front, right_front, right_back, left_back]\n\n    return polygon_contour\n\n\ndef joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene:\n    states = states.numpy()\n    simulated_trajectories = []\n    for i_object in range(len(object_ids)):\n        simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory(\n            center_x=states[i_object, :, 0], center_y=states[i_object, :, 1],\n            center_z=states[i_object, :, 
2], heading=states[i_object, :, 3],\n            object_id=object_ids[i_object].item()\n        ))\n    return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories)\n\n\nclass SMART(pl.LightningModule):\n\n    def __init__(self, model_config) -> None:\n        super(SMART, self).__init__()\n        self.save_hyperparameters()\n        self.model_config = model_config\n        self.warmup_steps = model_config.warmup_steps\n        self.lr = model_config.lr\n        self.total_steps = model_config.total_steps\n        self.dataset = model_config.dataset\n        self.input_dim = model_config.input_dim\n        self.hidden_dim = model_config.hidden_dim\n        self.output_dim = model_config.output_dim\n        self.output_head = model_config.output_head\n        self.num_historical_steps = model_config.num_historical_steps\n        self.num_future_steps = model_config.decoder.num_future_steps\n        self.num_freq_bands = model_config.num_freq_bands\n        self.vis_map = False\n        self.noise = True\n        module_dir = os.path.dirname(os.path.dirname(__file__))\n        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')\n        self.init_map_token()\n        self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl')\n        token_data = self.get_trajectory_token()\n        self.encoder = SMARTDecoder(\n            dataset=model_config.dataset,\n            input_dim=model_config.input_dim,\n            hidden_dim=model_config.hidden_dim,\n            num_historical_steps=model_config.num_historical_steps,\n            num_freq_bands=model_config.num_freq_bands,\n            num_heads=model_config.num_heads,\n            head_dim=model_config.head_dim,\n            dropout=model_config.dropout,\n            num_map_layers=model_config.decoder.num_map_layers,\n            num_agent_layers=model_config.decoder.num_agent_layers,\n            
pl2pl_radius=model_config.decoder.pl2pl_radius,\n            pl2a_radius=model_config.decoder.pl2a_radius,\n            a2a_radius=model_config.decoder.a2a_radius,\n            time_span=model_config.decoder.time_span,\n            map_token={'traj_src': self.map_token['traj_src']},\n            token_data=token_data,\n            token_size=model_config.decoder.token_size\n        )\n        self.minADE = minADE(max_guesses=1)\n        self.minFDE = minFDE(max_guesses=1)\n        self.TokenCls = TokenCls(max_guesses=1)\n\n        self.test_predictions = dict()\n        self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)\n        self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)\n        self.inference_token = False\n        self.rollout_num = 1\n\n    def get_trajectory_token(self):\n        token_data = pickle.load(open(self.token_path, 'rb'))\n        self.trajectory_token = token_data['token']\n        self.trajectory_token_traj = token_data['traj']\n        self.trajectory_token_all = token_data['token_all']\n        return token_data\n\n    def init_map_token(self):\n        self.argmin_sample_len = 3\n        map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))\n        self.map_token = {'traj_src': map_token_traj['traj_src'], }\n        traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],\n                                    self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])\n        indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()\n        self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)\n        self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)\n        self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)\n\n    def forward(self, data: HeteroData):\n        res = 
self.encoder(data)\n        return res\n\n    def inference(self, data: HeteroData):\n        res = self.encoder.inference(data)\n        return res\n\n    def maybe_autocast(self, dtype=torch.float16):\n        enable_autocast = self.device != torch.device(\"cpu\")\n\n        if enable_autocast:\n            return torch.cuda.amp.autocast(dtype=dtype)\n        else:\n            return contextlib.nullcontext()\n\n    def training_step(self,\n                      data,\n                      batch_idx):\n        data = self.match_token_map(data)\n        data = self.sample_pt_pred(data)\n        if isinstance(data, Batch):\n            data['agent']['av_index'] += data['agent']['ptr'][:-1]\n        pred = self(data)\n        next_token_prob = pred['next_token_prob']\n        next_token_idx_gt = pred['next_token_idx_gt']\n        next_token_eval_mask = pred['next_token_eval_mask']\n        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])\n        loss = cls_loss\n        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)\n        self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)\n        return loss\n\n    def validation_step(self,\n                        data,\n                        batch_idx):\n        data = self.match_token_map(data)\n        data = self.sample_pt_pred(data)\n        if isinstance(data, Batch):\n            data['agent']['av_index'] += data['agent']['ptr'][:-1]\n        pred = self(data)\n        next_token_idx = pred['next_token_idx']\n        next_token_idx_gt = pred['next_token_idx_gt']\n        next_token_eval_mask = pred['next_token_eval_mask']\n        next_token_prob = pred['next_token_prob']\n        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])\n        loss = cls_loss\n        self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], 
target=next_token_idx_gt[next_token_eval_mask],\n                        valid_mask=next_token_eval_mask[next_token_eval_mask])\n        self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)\n        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)\n\n        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]  # * (data['agent']['category'] == 3)\n        if self.inference_token:\n            pred = self.inference(data)\n            pos_a = pred['pos_a']\n            gt = pred['gt']\n            valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:]\n            pred_traj = pred['pred_traj']\n            # next_token_idx = pred['next_token_idx'][..., None]\n            # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:]\n            # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:]\n            # next_token_eval_mask[:, 1:] = False\n            # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],\n            #                      valid_mask=next_token_eval_mask[next_token_eval_mask])\n            # self.log('val_inference_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)\n            eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]\n\n            self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])\n            self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])\n            # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute())\n\n            self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)\n            self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)\n\n    def 
on_validation_start(self):\n        self.gt = []\n        self.pred = []\n        self.scenario_rollouts = []\n        self.batch_metric = defaultdict(list)\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)\n\n        def lr_lambda(current_step):\n            if current_step + 1 < self.warmup_steps:\n                return float(current_step + 1) / float(max(1, self.warmup_steps))\n            return max(\n                0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))\n            )\n\n        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)\n        return [optimizer], [lr_scheduler]\n\n    def load_params_from_file(self, filename, logger, to_cpu=False):\n        if not os.path.isfile(filename):\n            raise FileNotFoundError\n\n        logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))\n        loc_type = torch.device('cpu') if to_cpu else None\n        checkpoint = torch.load(filename, map_location=loc_type)\n        model_state_disk = checkpoint['state_dict']\n\n        version = checkpoint.get(\"version\", None)\n        if version is not None:\n            logger.info('==> Checkpoint trained from version: %s' % version)\n\n        logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')\n        model_state = self.state_dict()\n        model_state_disk_filter = {}\n        for key, val in model_state_disk.items():\n            if key in model_state and model_state_disk[key].shape == model_state[key].shape:\n                model_state_disk_filter[key] = val\n            else:\n                if key not in model_state:\n                    print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')\n                else:\n                    print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, 
model_shape={model_state[key].shape}')\n\n        model_state_disk = model_state_disk_filter\n\n        missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)\n\n        logger.info(f'Missing keys: {missing_keys}')\n        logger.info(f'The number of missing keys: {len(missing_keys)}')\n        logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')\n        logger.info('==> Done (total keys %d)' % (len(model_state)))\n\n        epoch = checkpoint.get('epoch', -1)\n        it = checkpoint.get('it', 0.0)\n\n        return it, epoch\n\n    def match_token_map(self, data):\n        traj_pos = data['map_save']['traj_pos'].to(torch.float)\n        traj_theta = data['map_save']['traj_theta'].to(torch.float)\n        pl_idx_list = data['map_save']['pl_idx_list']\n        token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)\n        token_src = self.map_token['traj_src'].to(traj_pos.device)\n        max_traj_len = self.map_token['traj_src'].shape[1]\n        pl_num = traj_pos.shape[0]\n\n        pt_token_pos = traj_pos[:, 0, :].clone()\n        pt_token_orientation = traj_theta.clone()\n        cos, sin = traj_theta.cos(), traj_theta.sin()\n        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)\n        rot_mat[..., 0, 0] = cos\n        rot_mat[..., 0, 1] = -sin\n        rot_mat[..., 1, 0] = sin\n        rot_mat[..., 1, 1] = cos\n        traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))\n        distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))\n        pt_token_id = torch.argmin(distance, dim=1)\n\n        if self.noise:\n            topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]\n            sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)\n            pt_token_id = torch.gather(topk_indices, 1, 
sample_topk).squeeze(-1)\n\n        cos, sin = traj_theta.cos(), traj_theta.sin()\n        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)\n        rot_mat[..., 0, 0] = cos\n        rot_mat[..., 0, 1] = sin\n        rot_mat[..., 1, 0] = -sin\n        rot_mat[..., 1, 1] = cos\n        token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),\n                                    rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]\n        token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)\n\n        pl_idx_full = pl_idx_list.clone()\n        token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])\n        count_nums = []\n        for pl in pl_idx_full.unique():\n            pt = token2pl[0, token2pl[1, :] == pl]\n            left_side = (data['pt_token']['side'][pt] == 0).sum()\n            right_side = (data['pt_token']['side'][pt] == 1).sum()\n            center_side = (data['pt_token']['side'][pt] == 2).sum()\n            count_nums.append(torch.Tensor([left_side, right_side, center_side]))\n        count_nums = torch.stack(count_nums, dim=0)\n        num_polyline = int(count_nums.max().item())\n        traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)\n        idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)\n        idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)  #\n        counts_num_expanded = count_nums.unsqueeze(-1)\n        mask_update = idx_matrix < counts_num_expanded\n        traj_mask[mask_update] = True\n\n        data['pt_token']['traj_mask'] = traj_mask\n        data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),\n                                                                 
           device=traj_pos.device, dtype=torch.float)], dim=-1)\n        data['pt_token']['orientation'] = pt_token_orientation\n        data['pt_token']['height'] = data['pt_token']['position'][:, -1]\n        data[('pt_token', 'to', 'map_polygon')] = {}\n        data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl\n        data['pt_token']['token_idx'] = pt_token_id\n        return data\n\n    def sample_pt_pred(self, data):\n        traj_mask = data['pt_token']['traj_mask']\n        raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)\n        masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)]\n        masked_pt_index = torch.sort(masked_pt_index, -1)[0]\n        pt_valid_mask = traj_mask.clone()\n        pt_valid_mask.scatter_(2, masked_pt_index, False)\n        pt_pred_mask = traj_mask.clone()\n        pt_pred_mask.scatter_(2, masked_pt_index, False)\n        tmp_mask = pt_pred_mask.clone()\n        tmp_mask[:, :, :] = True\n        tmp_mask.scatter_(2, masked_pt_index-1, False)\n        pt_pred_mask.masked_fill_(tmp_mask, False)\n        pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)\n        pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)\n\n        data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]\n        data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]\n        data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]\n\n        return data\n"
  },
  {
    "path": "smart/modules/__init__.py",
    "content": "from smart.modules.smart_decoder import SMARTDecoder\nfrom smart.modules.map_decoder import SMARTMapDecoder\nfrom smart.modules.agent_decoder import SMARTAgentDecoder\n"
  },
  {
    "path": "smart/modules/agent_decoder.py",
    "content": "import pickle\nfrom typing import Dict, Mapping, Optional\nimport torch\nimport torch.nn as nn\nfrom smart.layers import MLPLayer\nfrom smart.layers.attention_layer import AttentionLayer\nfrom smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding\nfrom torch_cluster import radius, radius_graph\nfrom torch_geometric.data import Batch, HeteroData\nfrom torch_geometric.utils import dense_to_sparse, subgraph\nfrom smart.utils import angle_between_2d_vectors, weight_init, wrap_angle\nimport math\n\n\ndef cal_polygon_contour(x, y, theta, width, length):\n    left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)\n    left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)\n    left_front = (left_front_x, left_front_y)\n\n    right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)\n    right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)\n    right_front = (right_front_x, right_front_y)\n\n    right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)\n    right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)\n    right_back = (right_back_x, right_back_y)\n\n    left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)\n    left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)\n    left_back = (left_back_x, left_back_y)\n    polygon_contour = [left_front, right_front, right_back, left_back]\n\n    return polygon_contour\n\n\nclass SMARTAgentDecoder(nn.Module):\n\n    def __init__(self,\n                 dataset: str,\n                 input_dim: int,\n                 hidden_dim: int,\n                 num_historical_steps: int,\n                 time_span: Optional[int],\n                 pl2a_radius: float,\n                 a2a_radius: float,\n                 num_freq_bands: int,\n                 num_layers: int,\n                
 num_heads: int,\n                 head_dim: int,\n                 dropout: float,\n                 token_data: Dict,\n                 token_size=512) -> None:\n        super(SMARTAgentDecoder, self).__init__()\n        self.dataset = dataset\n        self.input_dim = input_dim\n        self.hidden_dim = hidden_dim\n        self.num_historical_steps = num_historical_steps\n        self.time_span = time_span if time_span is not None else num_historical_steps\n        self.pl2a_radius = pl2a_radius\n        self.a2a_radius = a2a_radius\n        self.num_freq_bands = num_freq_bands\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.dropout = dropout\n\n        input_dim_x_a = 2\n        input_dim_r_t = 4\n        input_dim_r_pt2a = 3\n        input_dim_r_a2a = 3\n        input_dim_token = 8\n\n        self.type_a_emb = nn.Embedding(4, hidden_dim)\n        self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)\n\n        self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)\n        self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)\n        self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,\n                                           num_freq_bands=num_freq_bands)\n        self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,\n                                          num_freq_bands=num_freq_bands)\n        self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)\n        self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)\n        self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)\n        self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim)\n\n        self.t_attn_layers = nn.ModuleList(\n   
         [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,\n                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]\n        )\n        self.pt2a_attn_layers = nn.ModuleList(\n            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,\n                            bipartite=True, has_pos_emb=True) for _ in range(num_layers)]\n        )\n        self.a2a_attn_layers = nn.ModuleList(\n            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,\n                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]\n        )\n        self.token_size = token_size\n        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,\n                                           output_dim=self.token_size)\n        self.trajectory_token = token_data['token']\n        self.trajectory_token_traj = token_data['traj']\n        self.trajectory_token_all = token_data['token_all']\n        self.apply(weight_init)\n        self.shift = 5\n        self.beam_size = 5\n        self.hist_mask = True\n\n    def transform_rel(self, token_traj, prev_pos, prev_heading=None):\n        if prev_heading is None:\n            diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]\n            prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])\n\n        num_agent, num_step, traj_num, traj_dim = token_traj.shape\n        cos, sin = prev_heading.cos(), prev_heading.sin()\n        rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)\n        rot_mat[:, :, 0, 0] = cos\n        rot_mat[:, :, 0, 1] = -sin\n        rot_mat[:, :, 1, 0] = sin\n        rot_mat[:, :, 1, 1] = cos\n        agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)\n        agent_pred_rel = 
agent_diff_rel + prev_pos[:, :, -1:, :]\n        return agent_pred_rel\n\n    def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False):\n        num_agent, num_step, traj_dim = pos_a.shape\n        motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),\n                                     pos_a[:, 1:] - pos_a[:, :-1]], dim=1)\n\n        agent_type = data['agent']['type']\n        veh_mask = (agent_type == 0)\n        cyc_mask = (agent_type == 2)\n        ped_mask = (agent_type == 1)\n        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)\n        self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1))\n        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)\n        self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))\n        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)\n        self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))\n\n        if inference:\n            agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)\n            trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(\n                torch.float)\n            trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(\n                torch.float)\n            trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(\n                torch.float)\n            agent_token_traj_all[veh_mask] = torch.cat(\n                [trajectory_token_all_veh[:, 
:self.shift], trajectory_token_veh[:, None, ...]], dim=1)\n            agent_token_traj_all[ped_mask] = torch.cat(\n                [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)\n            agent_token_traj_all[cyc_mask] = torch.cat(\n                [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)\n\n        agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)\n        agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]\n        agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]\n        agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]\n\n        agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device)\n        agent_token_traj[veh_mask] = trajectory_token_veh\n        agent_token_traj[ped_mask] = trajectory_token_ped\n        agent_token_traj[cyc_mask] = trajectory_token_cyc\n\n        vel = data['agent']['token_velocity']\n\n        categorical_embs = [\n            self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step,\n                                                                            dim=0),\n\n            self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave(\n                repeats=num_step,\n                dim=0)\n        ]\n        feature_a = torch.stack(\n            [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),\n             angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),\n             ], dim=-1)\n\n        x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),\n                           categorical_embs=categorical_embs)\n        x_a = x_a.view(-1, num_step, self.hidden_dim)\n\n        feat_a = torch.cat((agent_token_emb, x_a), dim=-1)\n        feat_a = 
self.fusion_emb(feat_a)\n\n        if inference:\n            return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs\n        else:\n            return feat_a, agent_token_traj\n\n    def agent_predict_next(self, data, agent_category, feat_a):\n        num_agent, num_step, traj_dim = data['agent']['token_pos'].shape\n        agent_type = data['agent']['type']\n        veh_mask = (agent_type == 0)  # * agent_category==3\n        cyc_mask = (agent_type == 2)  # * agent_category==3\n        ped_mask = (agent_type == 1)  # * agent_category==3\n        token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device)\n        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])\n        token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])\n        token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])\n        return token_res\n\n    def agent_predict_next_inf(self, data, agent_category, feat_a):\n        num_agent, traj_dim = feat_a.shape\n        agent_type = data['agent']['type']\n\n        veh_mask = (agent_type == 0)  # * agent_category==3\n        cyc_mask = (agent_type == 2)  # * agent_category==3\n        ped_mask = (agent_type == 1)  # * agent_category==3\n\n        token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device)\n        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])\n        token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])\n        token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])\n\n        return token_res\n\n    def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None):\n        pos_t = pos_a.reshape(-1, self.input_dim)\n        head_t = head_a.reshape(-1)\n        head_vector_t = head_vector_a.reshape(-1, 2)\n        hist_mask = mask.clone()\n\n        if self.hist_mask and self.training:\n            hist_mask[\n          
      torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False\n            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)\n        elif inference_mask is not None:\n            mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)\n        else:\n            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)\n\n        edge_index_t = dense_to_sparse(mask_t)[0]\n        edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]]\n        edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift]\n        rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]\n        rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])\n        r_t = torch.stack(\n            [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),\n             angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),\n             rel_head_t,\n             edge_index_t[0] - edge_index_t[1]], dim=-1)\n        r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)\n        return edge_index_t, r_t\n\n    def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s):\n        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)\n        head_s = head_a.transpose(0, 1).reshape(-1)\n        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)\n        edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,\n                                      max_num_neighbors=300)\n        edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]\n        rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]\n        rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])\n        r_a2a = torch.stack(\n            [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),\n             
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),\n             rel_head_a2a], dim=-1)\n        r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)\n        return edge_index_a2a, r_a2a\n\n    def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a, mask,\n                             batch_s, batch_pl):\n        mask_pl2a = mask.clone()\n        mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)\n        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)\n        head_s = head_a.transpose(0, 1).reshape(-1)\n        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)\n        pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()\n        orient_pl = data['pt_token']['orientation'].contiguous()\n        pos_pl = pos_pl.repeat(num_step, 1)\n        orient_pl = orient_pl.repeat(num_step)\n        edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,\n                                 batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)\n        edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]\n        rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]\n        rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])\n        r_pl2a = torch.stack(\n            [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),\n             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),\n             rel_orient_pl2a], dim=-1)\n        r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)\n        return edge_index_pl2a, r_pl2a\n\n    def forward(self,\n                data: HeteroData,\n                map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n        pos_a = data['agent']['token_pos']\n        head_a = data['agent']['token_heading']\n        head_vector_a = 
torch.stack([head_a.cos(), head_a.sin()], dim=-1)\n        num_agent, num_step, traj_dim = pos_a.shape\n        agent_category = data['agent']['category']\n        agent_token_index = data['agent']['token_idx']\n        feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index,\n                                                              pos_a, head_vector_a)\n\n        agent_valid_mask = data['agent']['agent_valid_mask'].clone()\n        # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]\n        # agent_valid_mask[~eval_mask] = False\n        mask = agent_valid_mask\n        edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask)\n\n        if isinstance(data, Batch):\n            batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t\n                                 for t in range(num_step)], dim=0)\n            batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t\n                                  for t in range(num_step)], dim=0)\n        else:\n            batch_s = torch.arange(num_step,\n                                   device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])\n            batch_pl = torch.arange(num_step,\n                                    device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])\n\n        mask_s = mask.transpose(0, 1).reshape(-1)\n        edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s)\n        mask[agent_category != 3] = False\n        edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,\n                                                            head_vector_a, mask, batch_s, batch_pl)\n\n        for i in range(self.num_layers):\n            feat_a = feat_a.reshape(-1, self.hidden_dim)\n            feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)\n            
feat_a = feat_a.reshape(-1, num_step,\n                                    self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)\n            feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(\n                repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(\n                    -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)\n            feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)\n            feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)\n\n        num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape\n        next_token_prob = self.token_predict_head(feat_a)\n        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)\n        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)\n\n        next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)\n        next_token_eval_mask = mask.clone()\n        next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)\n        next_token_eval_mask[:, -1] = False\n\n        return {'x_a': feat_a,\n                'next_token_idx': next_token_idx,\n                'next_token_prob': next_token_prob,\n                'next_token_idx_gt': next_token_index_gt,\n                'next_token_eval_mask': next_token_eval_mask,\n                }\n\n    def inference(self,\n                  data: HeteroData,\n                  map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]\n        pos_a = data['agent']['token_pos'].clone()\n        head_a = data['agent']['token_heading'].clone()\n        num_agent, num_step, traj_dim = pos_a.shape\n        pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0\n        head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0\n        
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)\n\n        agent_valid_mask = data['agent']['agent_valid_mask'].clone()\n        agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True\n        agent_valid_mask[~eval_mask] = False\n        agent_token_index = data['agent']['token_idx']\n        agent_category = data['agent']['category']\n        feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(\n            data,\n            agent_category,\n            agent_token_index,\n            pos_a,\n            head_vector_a,\n            inference=True)\n\n        agent_type = data[\"agent\"][\"type\"]\n        veh_mask = (agent_type == 0)  # * agent_category==3\n        cyc_mask = (agent_type == 2)  # * agent_category==3\n        ped_mask = (agent_type == 1)  # * agent_category==3\n        av_mask = data[\"agent\"][\"av_index\"]\n\n        self.num_recurrent_steps_val = data[\"agent\"]['position'].shape[1]-self.num_historical_steps\n        pred_traj = torch.zeros(data[\"agent\"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device)\n        pred_head = torch.zeros(data[\"agent\"].num_nodes, self.num_recurrent_steps_val, device=feat_a.device)\n        pred_prob = torch.zeros(data[\"agent\"].num_nodes, self.num_recurrent_steps_val // self.shift, device=feat_a.device)\n        next_token_idx_list = []\n        mask = agent_valid_mask.clone()\n        feat_a_t_dict = {}\n        for t in range(self.num_recurrent_steps_val // self.shift):\n            if t == 0:\n                inference_mask = mask.clone()\n                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False\n            else:\n                inference_mask = torch.zeros_like(mask)\n                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True\n            edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, 
num_agent, mask, inference_mask)\n            if isinstance(data, Batch):\n                batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t\n                                     for t in range(num_step)], dim=0)\n                batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t\n                                      for t in range(num_step)], dim=0)\n            else:\n                batch_s = torch.arange(num_step,\n                                       device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])\n                batch_pl = torch.arange(num_step,\n                                        device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])\n            # In the inference stage, we only infer the current stage for recurrent\n            edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,\n                                                                head_vector_a,\n                                                                inference_mask, batch_s,\n                                                                batch_pl)\n            mask_s = inference_mask.transpose(0, 1).reshape(-1)\n            edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a,\n                                                                batch_s, mask_s)\n\n            for i in range(self.num_layers):\n                if i in feat_a_t_dict:\n                    feat_a = feat_a_t_dict[i]\n                feat_a = feat_a.reshape(-1, self.hidden_dim)\n                feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)\n                feat_a = feat_a.reshape(-1, num_step,\n                                        self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)\n                feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(\n                    repeats=num_step, dim=0).reshape(-1, num_step, 
self.hidden_dim).transpose(0, 1).reshape(\n                        -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)\n                feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)\n                feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)\n\n                if i+1 not in feat_a_t_dict:\n                    feat_a_t_dict[i+1] = feat_a\n                else:\n                    feat_a_t_dict[i+1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]\n\n            next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])\n\n            next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)\n\n            topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1)\n\n            expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)\n            next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index)\n\n            theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]\n            cos, sin = theta.cos(), theta.sin()\n            rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)\n            rot_mat[:, 0, 0] = cos\n            rot_mat[:, 0, 1] = sin\n            rot_mat[:, 1, 0] = -sin\n            rot_mat[:, 1, 1] = cos\n            agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),\n                                       rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(\n                                           -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)\n            agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]\n\n            sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device)\n            agent_pred_rel = 
agent_pred_rel.gather(dim=1,\n                                                   index=sample_index[..., None, None, None].expand(-1, -1, 6, 4,\n                                                                                                    2))[:, 0, ...]\n            pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0]\n            pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)\n            diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]\n            pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])\n\n            pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)\n            diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]\n            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])\n            head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta\n            next_token_idx = next_token_idx.gather(dim=1, index=sample_index)\n            next_token_idx = next_token_idx.squeeze(-1)\n            next_token_idx_list.append(next_token_idx[:, None])\n            agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[\n                next_token_idx[veh_mask]]\n            agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[\n                next_token_idx[ped_mask]]\n            agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[\n                next_token_idx[cyc_mask]]\n            motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),\n                                         pos_a[:, 1:] - pos_a[:, :-1]], dim=1)\n\n            head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)\n\n            vel = motion_vector_a.clone() / (0.1 * self.shift)\n            
vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0\n            motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0\n            x_a = torch.stack(\n                [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),\n                 angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)\n\n            x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),\n                               categorical_embs=categorical_embs)\n            x_a = x_a.view(-1, num_step, self.hidden_dim)\n\n            feat_a = torch.cat((agent_token_emb, x_a), dim=-1)\n            feat_a = self.fusion_emb(feat_a)\n\n        agent_valid_mask[agent_category != 3] = False\n\n        return {\n            'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],\n            'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],\n            'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(),\n            'valid_mask': agent_valid_mask[:, self.num_historical_steps:],\n            'pred_traj': pred_traj,\n            'pred_head': pred_head,\n            'next_token_idx': torch.cat(next_token_idx_list, dim=-1),\n            'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1),\n            'next_token_eval_mask': data['agent']['agent_valid_mask'],\n            'pred_prob': pred_prob,\n            'vel': vel\n        }\n"
  },
  {
    "path": "smart/modules/map_decoder.py",
    "content": "import os.path\nfrom typing import Dict\nimport torch\nimport torch.nn as nn\nfrom torch_cluster import radius_graph\nfrom torch_geometric.data import Batch\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.utils import dense_to_sparse, subgraph\nfrom smart.utils.nan_checker import check_nan_inf\nfrom smart.layers.attention_layer import AttentionLayer\nfrom smart.layers import MLPLayer\nfrom smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding\nfrom smart.utils import angle_between_2d_vectors\nfrom smart.utils import merge_edges\nfrom smart.utils import weight_init\nfrom smart.utils import wrap_angle\nimport pickle\n\n\nclass SMARTMapDecoder(nn.Module):\n\n    def __init__(self,\n                 dataset: str,\n                 input_dim: int,\n                 hidden_dim: int,\n                 num_historical_steps: int,\n                 pl2pl_radius: float,\n                 num_freq_bands: int,\n                 num_layers: int,\n                 num_heads: int,\n                 head_dim: int,\n                 dropout: float,\n                 map_token) -> None:\n        super(SMARTMapDecoder, self).__init__()\n        self.dataset = dataset\n        self.input_dim = input_dim\n        self.hidden_dim = hidden_dim\n        self.num_historical_steps = num_historical_steps\n        self.pl2pl_radius = pl2pl_radius\n        self.num_freq_bands = num_freq_bands\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n        self.head_dim = head_dim\n        self.dropout = dropout\n\n        if input_dim == 2:\n            input_dim_r_pt2pt = 3\n        elif input_dim == 3:\n            input_dim_r_pt2pt = 4\n        else:\n            raise ValueError('{} is not a valid dimension'.format(input_dim))\n\n        self.type_pt_emb = nn.Embedding(17, hidden_dim)\n        self.side_pt_emb = nn.Embedding(4, hidden_dim)\n        self.polygon_type_emb = nn.Embedding(4, hidden_dim)\n        
self.light_pl_emb = nn.Embedding(4, hidden_dim)\n\n        self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,\n                                            num_freq_bands=num_freq_bands)\n        self.pt2pt_layers = nn.ModuleList(\n            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,\n                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]\n        )\n        self.token_size = 1024\n        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,\n                                           output_dim=self.token_size)\n        input_dim_token = 22\n        self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)\n        self.map_token = map_token\n        self.apply(weight_init)\n        self.mask_pt = False\n\n    def maybe_autocast(self, dtype=torch.float32):\n        return torch.cuda.amp.autocast(dtype=dtype)\n\n    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:\n        pt_valid_mask = data['pt_token']['pt_valid_mask']\n        pt_pred_mask = data['pt_token']['pt_pred_mask']\n        pt_target_mask = data['pt_token']['pt_target_mask']\n        mask_s = pt_valid_mask\n\n        pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()\n        orient_pt = data['pt_token']['orientation'].contiguous()\n        orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)\n        token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)\n        pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))\n        pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]\n\n        if self.input_dim == 2:\n            x_pt = pt_token_emb\n        elif self.input_dim == 3:\n            x_pt = pt_token_emb\n        else:\n            raise ValueError('{} is not a valid 
dimension'.format(self.input_dim))\n\n        token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']\n        token_light_type = data['map_polygon']['light_type'][token2pl[1]]\n        x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),\n                                 self.polygon_type_emb(data['pt_token']['pl_type'].long()),\n                                 self.light_pl_emb(token_light_type.long()),]\n        x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)\n        edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,\n                                        batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,\n                                        loop=False, max_num_neighbors=100)\n        if self.mask_pt:\n            edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]\n        rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]\n        rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])\n        if self.input_dim == 2:\n            r_pt2pt = torch.stack(\n                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),\n                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],\n                                          nbr_vector=rel_pos_pt2pt[:, :2]),\n                 rel_orient_pt2pt], dim=-1)\n        elif self.input_dim == 3:\n            r_pt2pt = torch.stack(\n                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),\n                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],\n                                          nbr_vector=rel_pos_pt2pt[:, :2]),\n                 rel_pos_pt2pt[:, -1],\n                 rel_orient_pt2pt], dim=-1)\n        else:\n            raise ValueError('{} is not a valid dimension'.format(self.input_dim))\n        r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, 
categorical_embs=None)\n        for i in range(self.num_layers):\n            x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)\n\n        next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])\n        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)\n        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)\n        next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]\n\n        return {\n            'x_pt': x_pt,\n            'map_next_token_idx': next_token_idx,\n            'map_next_token_prob': next_token_prob,\n            'map_next_token_idx_gt': next_token_index_gt,\n            'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]\n        }\n"
  },
  {
    "path": "smart/modules/smart_decoder.py",
    "content": "from typing import Dict, Optional\nimport torch\nimport torch.nn as nn\nfrom torch_geometric.data import HeteroData\nfrom smart.modules.agent_decoder import SMARTAgentDecoder\nfrom smart.modules.map_decoder import SMARTMapDecoder\n\n\nclass SMARTDecoder(nn.Module):\n\n    def __init__(self,\n                 dataset: str,\n                 input_dim: int,\n                 hidden_dim: int,\n                 num_historical_steps: int,\n                 pl2pl_radius: float,\n                 time_span: Optional[int],\n                 pl2a_radius: float,\n                 a2a_radius: float,\n                 num_freq_bands: int,\n                 num_map_layers: int,\n                 num_agent_layers: int,\n                 num_heads: int,\n                 head_dim: int,\n                 dropout: float,\n                 map_token: Dict,\n                 token_data: Dict,\n                 use_intention=False,\n                 token_size=512) -> None:\n        super(SMARTDecoder, self).__init__()\n        self.map_encoder = SMARTMapDecoder(\n            dataset=dataset,\n            input_dim=input_dim,\n            hidden_dim=hidden_dim,\n            num_historical_steps=num_historical_steps,\n            pl2pl_radius=pl2pl_radius,\n            num_freq_bands=num_freq_bands,\n            num_layers=num_map_layers,\n            num_heads=num_heads,\n            head_dim=head_dim,\n            dropout=dropout,\n            map_token=map_token\n        )\n        self.agent_encoder = SMARTAgentDecoder(\n            dataset=dataset,\n            input_dim=input_dim,\n            hidden_dim=hidden_dim,\n            num_historical_steps=num_historical_steps,\n            time_span=time_span,\n            pl2a_radius=pl2a_radius,\n            a2a_radius=a2a_radius,\n            num_freq_bands=num_freq_bands,\n            num_layers=num_agent_layers,\n            num_heads=num_heads,\n            head_dim=head_dim,\n            dropout=dropout,\n        
    token_size=token_size,\n            token_data=token_data\n        )\n        self.map_enc = None\n\n    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:\n        map_enc = self.map_encoder(data)\n        agent_enc = self.agent_encoder(data, map_enc)\n        return {**map_enc, **agent_enc}\n\n    def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:\n        map_enc = self.map_encoder(data)\n        agent_enc = self.agent_encoder.inference(data, map_enc)\n        return {**map_enc, **agent_enc}\n\n    def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:\n        agent_enc = self.agent_encoder.inference(data, map_enc)\n        return {**map_enc, **agent_enc}\n"
  },
  {
    "path": "smart/preprocess/__init__.py",
    "content": ""
  },
  {
    "path": "smart/preprocess/preprocess.py",
    "content": "import numpy as np\nimport pandas as pd\nimport os\nimport torch\nfrom typing import Any, Dict, List, Optional\n\npredict_unseen_agents = False\nvector_repr = True\n_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']\n_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']\n_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']\n_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',\n                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',\n                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',\n                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']\n_point_sides = ['LEFT', 'RIGHT', 'CENTER']\n_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']\n_polygon_is_intersections = [True, False, None]\n\n\nLane_type_hash = {\n    4: \"BIKE\",\n    3: \"VEHICLE\",\n    2: \"VEHICLE\",\n    1: \"BUS\"\n}\n\nboundary_type_hash = {\n        5: \"UNKNOWN\",\n        6: \"DASHED_WHITE\",\n        7: \"SOLID_WHITE\",\n        8: \"DOUBLE_DASH_WHITE\",\n        9: \"DASHED_YELLOW\",\n        10: \"DOUBLE_DASH_YELLOW\",\n        11: \"SOLID_YELLOW\",\n        12: \"DOUBLE_SOLID_YELLOW\",\n        13: \"DASH_SOLID_YELLOW\",\n        14: \"UNKNOWN\",\n        15: \"EDGE\",\n        16: \"EDGE\"\n}\n\n\ndef get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:\n    if not predict_unseen_agents:  # filter out agents that are unseen during the historical time steps\n        historical_df = df[df['timestep'] == num_historical_steps-1]\n        agent_ids = list(historical_df['track_id'].unique())\n        df = df[df['track_id'].isin(agent_ids)]\n    else:\n        agent_ids = list(df['track_id'].unique())\n\n    num_agents = len(agent_ids)\n    # initialization\n    valid_mask = torch.zeros(num_agents, 
num_steps, dtype=torch.bool)\n    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)\n    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)\n    agent_id: List[Optional[str]] = [None] * num_agents\n    agent_type = torch.zeros(num_agents, dtype=torch.uint8)\n    agent_category = torch.zeros(num_agents, dtype=torch.uint8)\n    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)\n    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)\n    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)\n    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)\n\n    for track_id, track_df in df.groupby('track_id'):\n        agent_idx = agent_ids.index(track_id)\n        agent_steps = track_df['timestep'].values\n\n        valid_mask[agent_idx, agent_steps] = True\n        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]\n        predict_mask[agent_idx, agent_steps] = True\n        if vector_repr:  # a time step t is valid only when both t and t-1 are valid\n            valid_mask[agent_idx, 1: num_historical_steps] = (\n                valid_mask[agent_idx, :num_historical_steps - 1] &\n                valid_mask[agent_idx, 1: num_historical_steps])\n            valid_mask[agent_idx, 0] = False\n        predict_mask[agent_idx, :num_historical_steps] = False\n        if not current_valid_mask[agent_idx]:\n            predict_mask[agent_idx, num_historical_steps:] = False\n\n        agent_id[agent_idx] = track_id\n        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])\n        agent_category[agent_idx] = track_df['object_category'].values[0]\n        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,\n                                                                          track_df['position_y'].values,\n                                                                          
track_df['position_z'].values],\n                                                                         axis=-1)).float()\n        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()\n        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,\n                                                                          track_df['velocity_y'].values],\n                                                                         axis=-1)).float()\n        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,\n                                                                       track_df['width'].values,\n                                                                       track_df[\"height\"].values],\n                                                                      axis=-1)).float()\n    av_idx = agent_id.index(av_id)\n\n    return {\n        'num_nodes': num_agents,\n        'av_index': av_idx,\n        'valid_mask': valid_mask,\n        'predict_mask': predict_mask,\n        'id': agent_id,\n        'type': agent_type,\n        'category': agent_category,\n        'position': position,\n        'heading': heading,\n        'velocity': velocity,\n        'shape': shape\n    }"
  },
  {
    "path": "smart/tokens/__init__.py",
    "content": ""
  },
  {
    "path": "smart/transforms/__init__.py",
    "content": "from smart.transforms.target_builder import WaymoTargetBuilder\n"
  },
  {
    "path": "smart/transforms/target_builder.py",
    "content": "\nimport numpy as np\nimport torch\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.transforms import BaseTransform\nfrom smart.utils import wrap_angle\nfrom smart.utils.log import Logging\n\n\ndef to_16(data):\n    if isinstance(data, dict):\n        for key, value in data.items():\n            new_value = to_16(value)\n            data[key] = new_value\n    if isinstance(data, torch.Tensor):\n        if data.dtype == torch.float32:\n            data = data.to(torch.float16)\n    return data\n\n\ndef tofloat32(data):\n    for name in data:\n        value = data[name]\n        if isinstance(value, dict):\n            value = tofloat32(value)\n        elif isinstance(value, torch.Tensor) and value.dtype == torch.float64:\n            value = value.to(torch.float32)\n        data[name] = value\n    return data\n\n\nclass WaymoTargetBuilder(BaseTransform):\n\n    def __init__(self,\n                 num_historical_steps: int,\n                 num_future_steps: int,\n                 mode=\"train\") -> None:\n        self.num_historical_steps = num_historical_steps\n        self.num_future_steps = num_future_steps\n        self.mode = mode\n        self.num_features = 3\n        self.augment = False\n        self.logger = Logging().log(level='DEBUG')\n\n    def score_ego_agent(self, agent):\n        av_index = agent['av_index']\n        agent[\"category\"][av_index] = 5\n        return agent\n\n    def clip(self, agent, max_num=32):\n        av_index = agent[\"av_index\"]\n        valid = agent['valid_mask']\n        ego_pos = agent[\"position\"][av_index]\n        obstacle_mask = agent['type'] == 3\n        distance = torch.norm(agent[\"position\"][:, self.num_historical_steps-1, :2] - ego_pos[self.num_historical_steps-1, :2], dim=-1)  # keep the closest 100 vehicles near the ego car\n        distance[obstacle_mask] = 10e5\n        sort_idx = distance.sort()[1]\n        mask = torch.zeros(valid.shape[0])\n        
mask[sort_idx[:max_num]] = 1\n        mask = mask.to(torch.bool)\n        mask[av_index] = True\n        new_av_index = mask[:av_index].sum()\n        agent[\"num_nodes\"] = int(mask.sum())\n        agent[\"av_index\"] = int(new_av_index)\n        excluded = [\"num_nodes\", \"av_index\", \"ego\"]\n        for key, val in agent.items():\n            if key in excluded:\n                continue\n            if key == \"id\":\n                val = list(np.array(val)[mask])\n                agent[key] = val\n                continue\n            if len(val.size()) > 1:\n                agent[key] = val[mask, ...]\n            else:\n                agent[key] = val[mask]\n        return agent\n\n    def score_nearby_vehicle(self, agent, max_num=10):\n        av_index = agent['av_index']\n        agent[\"category\"] = torch.zeros_like(agent[\"category\"])\n        obstacle_mask = agent['type'] == 3\n        pos = agent[\"position\"][av_index, self.num_historical_steps, :2]\n        distance = torch.norm(agent[\"position\"][:, self.num_historical_steps, :2] - pos, dim=-1)\n        distance[obstacle_mask] = 10e5\n        sort_idx = distance.sort()[1]\n        nearby_mask = torch.zeros(distance.shape[0])\n        nearby_mask[sort_idx[1:max_num]] = 1\n        nearby_mask = nearby_mask.bool()\n        agent[\"category\"][nearby_mask] = 3\n        agent[\"category\"][obstacle_mask] = 0\n\n    def score_trained_vehicle(self, agent, max_num=10, min_distance=0):\n        av_index = agent['av_index']\n        agent[\"category\"] = torch.zeros_like(agent[\"category\"])\n        pos = agent[\"position\"][av_index, self.num_historical_steps, :2]\n        distance = torch.norm(agent[\"position\"][:, self.num_historical_steps, :2] - pos, dim=-1)\n        distance_all_time = torch.norm(agent[\"position\"][:, :, :2] - agent[\"position\"][av_index, :, :2], dim=-1)\n        invalid_mask = distance_all_time < 150  # we do not believe the perception out of range of 150 meters\n        
agent[\"valid_mask\"] = agent[\"valid_mask\"] * invalid_mask\n        # we do not predict vehicle  too far away from ego car\n        closet_vehicle = distance < 100\n        valid = agent['valid_mask']\n        valid_current = valid[:, (self.num_historical_steps):]\n        valid_counts = valid_current.sum(1)\n        counts_vehicle = valid_counts >= 1\n        no_backgroud = agent['type'] != 3\n        vehicle2pred = closet_vehicle & counts_vehicle & no_backgroud\n        if vehicle2pred.sum() > max_num:\n            # too many still vehicle so that train the model using the moving vehicle as much as possible\n            true_indices = torch.nonzero(vehicle2pred).squeeze(1)\n            selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]\n            vehicle2pred.fill_(False)\n            vehicle2pred[selected_indices] = True\n        agent[\"category\"][vehicle2pred] = 3\n\n    def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):\n        origin = position[:, num_historical_steps - 1]\n        theta = heading[:, num_historical_steps - 1]\n        cos, sin = theta.cos(), theta.sin()\n        rot_mat = theta.new_zeros(num_nodes, 2, 2)\n        rot_mat[:, 0, 0] = cos\n        rot_mat[:, 0, 1] = -sin\n        rot_mat[:, 1, 0] = sin\n        rot_mat[:, 1, 1] = cos\n        target = origin.new_zeros(num_nodes, num_future_steps, 4)\n        target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -\n                                    origin[:, :2].unsqueeze(1), rot_mat)\n        his = origin.new_zeros(num_nodes, num_historical_steps, 4)\n        his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -\n                                 origin[:, :2].unsqueeze(1), rot_mat)\n        if position.size(2) == 3:\n            target[..., 2] = (position[:, num_historical_steps:, 2] -\n                              origin[:, 2].unsqueeze(-1))\n            his[..., 2] = (position[:, 
:num_historical_steps, 2] -\n                           origin[:, 2].unsqueeze(-1))\n            target[..., 3] = wrap_angle(heading[:, num_historical_steps:] -\n                                        theta.unsqueeze(-1))\n            his[..., 3] = wrap_angle(heading[:, :num_historical_steps] -\n                                     theta.unsqueeze(-1))\n        else:\n            target[..., 2] = wrap_angle(heading[:, num_historical_steps:] -\n                                        theta.unsqueeze(-1))\n            his[..., 2] = wrap_angle(heading[:, :num_historical_steps] -\n                                     theta.unsqueeze(-1))\n        return his, target\n\n    def __call__(self, data) -> HeteroData:\n        agent = data[\"agent\"]\n        self.score_ego_agent(agent)\n        self.score_trained_vehicle(agent, max_num=32)\n        return HeteroData(data)\n"
  },
  {
    "path": "smart/utils/__init__.py",
    "content": "\nfrom smart.utils.geometry import angle_between_2d_vectors\nfrom smart.utils.geometry import angle_between_3d_vectors\nfrom smart.utils.geometry import side_to_directed_lineseg\nfrom smart.utils.geometry import wrap_angle\nfrom smart.utils.graph import add_edges\nfrom smart.utils.graph import bipartite_dense_to_sparse\nfrom smart.utils.graph import complete_graph\nfrom smart.utils.graph import merge_edges\nfrom smart.utils.graph import unbatch\nfrom smart.utils.list import safe_list_index\nfrom smart.utils.weight_init import weight_init\n"
  },
  {
    "path": "smart/utils/cluster_reader.py",
    "content": "import io\nimport pickle\nimport pandas as pd\nimport json\n\n\nclass LoadScenarioFromCeph:\n    def __init__(self):\n        from petrel_client.client import Client\n        self.file_client = Client('~/petreloss.conf')\n\n    def list(self, dir_path):\n        return list(self.file_client.list(dir_path))\n\n    def save(self, data, url):\n        self.file_client.put(url, pickle.dumps(data))\n\n    def read_correct_csv(self, scenario_path):\n        output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine=\"python\")\n        return output\n\n    def contains(self, url):\n        return self.file_client.contains(url)\n\n    def read_string(self, csv_url):\n        from io import StringIO\n        df = pd.read_csv(StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep='\\s+', low_memory=False)\n        return df\n\n    def read(self, scenario_path):\n        with io.BytesIO(self.file_client.get(scenario_path)) as f:\n            datas = pickle.load(f)\n            return datas\n\n    def read_json(self, path):\n        with io.BytesIO(self.file_client.get(path)) as f:\n            data = json.load(f)\n            return data\n\n    def read_csv(self, scenario_path):\n        return pickle.loads(self.file_client.get(scenario_path))\n\n    def read_model(self, model_path):\n        with io.BytesIO(self.file_client.get(model_path)) as f:\n            pass\n"
  },
  {
    "path": "smart/utils/config.py",
    "content": "import os\nimport yaml\nimport easydict\n\n\ndef load_config_act(path):\n    \"\"\" load config file\"\"\"\n    with open(path, 'r') as f:\n        cfg = yaml.load(f, Loader=yaml.FullLoader)\n    return easydict.EasyDict(cfg)\n\n\ndef load_config_init(path):\n    \"\"\" load config file\"\"\"\n    path = os.path.join('init/configs', f'{path}.yaml')\n    with open(path, 'r') as f:\n        cfg = yaml.load(f, Loader=yaml.FullLoader)\n    return cfg\n"
  },
  {
    "path": "smart/utils/geometry.py",
    "content": "\nimport math\n\nimport torch\n\n\ndef angle_between_2d_vectors(\n        ctr_vector: torch.Tensor,\n        nbr_vector: torch.Tensor) -> torch.Tensor:\n    return torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0],\n                       (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1))\n\n\ndef angle_between_3d_vectors(\n        ctr_vector: torch.Tensor,\n        nbr_vector: torch.Tensor) -> torch.Tensor:\n    return torch.atan2(torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1),\n                       (ctr_vector * nbr_vector).sum(dim=-1))\n\n\ndef side_to_directed_lineseg(\n        query_point: torch.Tensor,\n        start_point: torch.Tensor,\n        end_point: torch.Tensor) -> str:\n    cond = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) -\n            (end_point[1] - start_point[1]) * (query_point[0] - start_point[0]))\n    if cond > 0:\n        return 'LEFT'\n    elif cond < 0:\n        return 'RIGHT'\n    else:\n        return 'CENTER'\n\n\ndef wrap_angle(\n        angle: torch.Tensor,\n        min_val: float = -math.pi,\n        max_val: float = math.pi) -> torch.Tensor:\n    return min_val + (angle + max_val) % (max_val - min_val)\n"
  },
  {
    "path": "smart/utils/graph.py",
    "content": "\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch_geometric.utils import coalesce\nfrom torch_geometric.utils import degree\n\n\ndef add_edges(\n        from_edge_index: torch.Tensor,\n        to_edge_index: torch.Tensor,\n        from_edge_attr: Optional[torch.Tensor] = None,\n        to_edge_attr: Optional[torch.Tensor] = None,\n        replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n    from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype)\n    mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) &\n            (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0)))\n    if replace:\n        to_mask = mask.any(dim=1)\n        if from_edge_attr is not None and to_edge_attr is not None:\n            from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)\n            to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0)\n        to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1)\n    else:\n        from_mask = mask.any(dim=0)\n        if from_edge_attr is not None and to_edge_attr is not None:\n            from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)\n            to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0)\n        to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1)\n    return to_edge_index, to_edge_attr\n\n\ndef merge_edges(\n        edge_indices: List[torch.Tensor],\n        edge_attrs: Optional[List[torch.Tensor]] = None,\n        reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n    edge_index = torch.cat(edge_indices, dim=1)\n    if edge_attrs is not None:\n        edge_attr = torch.cat(edge_attrs, dim=0)\n    else:\n        edge_attr = None\n    return coalesce(edge_index=edge_index, edge_attr=edge_attr, 
reduce=reduce)\n\n\ndef complete_graph(\n        num_nodes: Union[int, Tuple[int, int]],\n        ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,\n        loop: bool = False,\n        device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:\n    if ptr is None:\n        if isinstance(num_nodes, int):\n            num_src, num_dst = num_nodes, num_nodes\n        else:\n            num_src, num_dst = num_nodes\n        edge_index = torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),\n                                          torch.arange(num_dst, dtype=torch.long, device=device)).t()\n    else:\n        if isinstance(ptr, torch.Tensor):\n            ptr_src, ptr_dst = ptr, ptr\n            num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1]\n        else:\n            ptr_src, ptr_dst = ptr\n            num_src_batch = ptr_src[1:] - ptr_src[:-1]\n            num_dst_batch = ptr_dst[1:] - ptr_dst[:-1]\n        edge_index = torch.cat(\n            [torch.cartesian_prod(torch.arange(num_src, dtype=torch.long, device=device),\n                                  torch.arange(num_dst, dtype=torch.long, device=device)) + p\n             for num_src, num_dst, p in zip(num_src_batch, num_dst_batch, torch.stack([ptr_src, ptr_dst], dim=1))],\n            dim=0)\n        edge_index = edge_index.t()\n    if isinstance(num_nodes, int) and not loop:\n        edge_index = edge_index[:, edge_index[0] != edge_index[1]]\n    return edge_index.contiguous()\n\n\ndef bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:\n    index = adj.nonzero(as_tuple=True)\n    if len(index) == 3:\n        batch_src = index[0] * adj.size(1)\n        batch_dst = index[0] * adj.size(2)\n        index = (batch_src + index[1], batch_dst + index[2])\n    return torch.stack(index, dim=0)\n\n\ndef unbatch(\n        src: torch.Tensor,\n        batch: torch.Tensor,\n        dim: int = 0) -> List[torch.Tensor]:\n    sizes = degree(batch, 
dtype=torch.long).tolist()\n    return src.split(sizes, dim)\n"
  },
  {
    "path": "smart/utils/list.py",
    "content": "\nfrom typing import Any, List, Optional\n\n\ndef safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:\n    try:\n        return ls.index(elem)\n    except ValueError:\n        return None\n"
  },
  {
    "path": "smart/utils/log.py",
    "content": "import logging\nimport time\nimport os\n\n\nclass Logging:\n\n    def make_log_dir(self, dirname='logs'):\n        now_dir = os.path.dirname(__file__)\n        path = os.path.join(now_dir, dirname)\n        path = os.path.normpath(path)\n        if not os.path.exists(path):\n            os.mkdir(path)\n        return path\n\n    def get_log_filename(self):\n        filename = \"{}.log\".format(time.strftime(\"%Y-%m-%d\",time.localtime()))\n        filename = os.path.join(self.make_log_dir(), filename)\n        filename = os.path.normpath(filename)\n        return filename\n\n    def log(self, level='DEBUG', name=\"simagent\"):\n        logger = logging.getLogger(name)\n        level = getattr(logging, level)\n        logger.setLevel(level)\n        if not logger.handlers:\n            sh = logging.StreamHandler()\n            fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding=\"utf-8\")\n            fmt = logging.Formatter(\"%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s\")\n            sh.setFormatter(fmt=fmt)\n            fh.setFormatter(fmt=fmt)\n            logger.addHandler(sh)\n            logger.addHandler(fh)\n        return logger\n\n    def add_log(self, logger, level='DEBUG'):\n        level = getattr(logging, level)\n        logger.setLevel(level)\n        if not logger.handlers:\n            sh = logging.StreamHandler()\n            fh = logging.FileHandler(filename=self.get_log_filename(), mode='a',encoding=\"utf-8\")\n            fmt = logging.Formatter(\"%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s\")\n            sh.setFormatter(fmt=fmt)\n            fh.setFormatter(fmt=fmt)\n            logger.addHandler(sh)\n            logger.addHandler(fh)\n        return logger\n\n\nif __name__ == '__main__':\n    logger = Logging().log(level='INFO')\n    logger.debug(\"1111111111111111111111\") # emit a log record via the logger\n    logger.info(\"222222222222222222222222\")\n    
logger.error(\"附件为IP飞机外婆家二分IP文件放\")\n    logger.warning(\"3333333333333333333333333333\")\n    logger.critical(\"44444444444444444444444444\")\n"
  },
  {
    "path": "smart/utils/nan_checker.py",
    "content": "import torch\n\ndef check_nan_inf(t, s):\n    assert not torch.isinf(t).any(), f\"{s} is inf, {t}\"\n    assert not torch.isnan(t).any(), f\"{s} is nan, {t}\""
  },
  {
    "path": "smart/utils/weight_init.py",
    "content": "\nimport torch.nn as nn\n\n\ndef weight_init(m: nn.Module) -> None:\n    if isinstance(m, nn.Linear):\n        nn.init.xavier_uniform_(m.weight)\n        if m.bias is not None:\n            nn.init.zeros_(m.bias)\n    elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n        fan_in = m.in_channels / m.groups\n        fan_out = m.out_channels / m.groups\n        bound = (6.0 / (fan_in + fan_out)) ** 0.5\n        nn.init.uniform_(m.weight, -bound, bound)\n        if m.bias is not None:\n            nn.init.zeros_(m.bias)\n    elif isinstance(m, nn.Embedding):\n        nn.init.normal_(m.weight, mean=0.0, std=0.02)\n    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):\n        nn.init.ones_(m.weight)\n        nn.init.zeros_(m.bias)\n    elif isinstance(m, nn.LayerNorm):\n        nn.init.ones_(m.weight)\n        nn.init.zeros_(m.bias)\n    elif isinstance(m, nn.MultiheadAttention):\n        if m.in_proj_weight is not None:\n            fan_in = m.embed_dim\n            fan_out = m.embed_dim\n            bound = (6.0 / (fan_in + fan_out)) ** 0.5\n            nn.init.uniform_(m.in_proj_weight, -bound, bound)\n        else:\n            nn.init.xavier_uniform_(m.q_proj_weight)\n            nn.init.xavier_uniform_(m.k_proj_weight)\n            nn.init.xavier_uniform_(m.v_proj_weight)\n        if m.in_proj_bias is not None:\n            nn.init.zeros_(m.in_proj_bias)\n        nn.init.xavier_uniform_(m.out_proj.weight)\n        if m.out_proj.bias is not None:\n            nn.init.zeros_(m.out_proj.bias)\n        if m.bias_k is not None:\n            nn.init.normal_(m.bias_k, mean=0.0, std=0.02)\n        if m.bias_v is not None:\n            nn.init.normal_(m.bias_v, mean=0.0, std=0.02)\n    elif isinstance(m, (nn.LSTM, nn.LSTMCell)):\n        for name, param in m.named_parameters():\n            if 'weight_ih' in name:\n                for ih in param.chunk(4, 0):\n                    nn.init.xavier_uniform_(ih)\n            elif 
'weight_hh' in name:\n                for hh in param.chunk(4, 0):\n                    nn.init.orthogonal_(hh)\n            elif 'weight_hr' in name:\n                nn.init.xavier_uniform_(param)\n            elif 'bias_ih' in name:\n                nn.init.zeros_(param)\n            elif 'bias_hh' in name:\n                nn.init.zeros_(param)\n                nn.init.ones_(param.chunk(4, 0)[1])\n    elif isinstance(m, (nn.GRU, nn.GRUCell)):\n        for name, param in m.named_parameters():\n            if 'weight_ih' in name:\n                for ih in param.chunk(3, 0):\n                    nn.init.xavier_uniform_(ih)\n            elif 'weight_hh' in name:\n                for hh in param.chunk(3, 0):\n                    nn.init.orthogonal_(hh)\n            elif 'bias_ih' in name:\n                nn.init.zeros_(param)\n            elif 'bias_hh' in name:\n                nn.init.zeros_(param)\n"
  },
  {
    "path": "train.py",
    "content": "\nfrom argparse import ArgumentParser\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import LearningRateMonitor\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom pytorch_lightning.strategies import DDPStrategy\nfrom smart.utils.config import load_config_act\nfrom smart.datamodules import MultiDataModule\nfrom smart.model import SMART\nfrom smart.utils.log import Logging\n\n\nif __name__ == '__main__':\n    parser = ArgumentParser()\n    Predictor_hash = {\"smart\": SMART, }\n    parser.add_argument('--config', type=str, default='configs/train/train_scalable.yaml')\n    parser.add_argument('--pretrain_ckpt', type=str, default=\"\")\n    parser.add_argument('--ckpt_path', type=str, default=\"\")\n    parser.add_argument('--save_ckpt_path', type=str, default=\"\")\n    args = parser.parse_args()\n    config = load_config_act(args.config)\n    Predictor = Predictor_hash[config.Model.predictor]\n    strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)\n    Data_config = config.Dataset\n    datamodule = MultiDataModule(**vars(Data_config))\n\n    if args.pretrain_ckpt == \"\":\n        model = Predictor(config.Model)\n    else:\n        logger = Logging().log(level='DEBUG')\n        model = Predictor(config.Model)\n        model.load_params_from_file(filename=args.pretrain_ckpt,\n                                    logger=logger)\n    trainer_config = config.Trainer\n    model_checkpoint = ModelCheckpoint(dirpath=args.save_ckpt_path,\n                                       filename=\"{epoch:02d}\",\n                                       monitor='val_cls_acc',\n                                       every_n_epochs=1,\n                                       save_top_k=5,\n                                       mode='max')\n    lr_monitor = LearningRateMonitor(logging_interval='epoch')\n    trainer = pl.Trainer(accelerator=trainer_config.accelerator, devices=trainer_config.devices,\n               
          strategy=strategy,\n                         accumulate_grad_batches=trainer_config.accumulate_grad_batches,\n                         num_nodes=trainer_config.num_nodes,\n                         callbacks=[model_checkpoint, lr_monitor],\n                         max_epochs=trainer_config.max_epochs,\n                         num_sanity_val_steps=0,\n                         gradient_clip_val=0.5)\n    if args.ckpt_path == \"\":\n        trainer.fit(model,\n                    datamodule)\n    else:\n        trainer.fit(model,\n                    datamodule,\n                    ckpt_path=args.ckpt_path)\n"
  },
  {
    "path": "val.py",
    "content": "\nfrom argparse import ArgumentParser\nimport pytorch_lightning as pl\nfrom torch_geometric.loader import DataLoader\nfrom smart.datasets.scalable_dataset import MultiDataset\nfrom smart.model import SMART\nfrom smart.transforms import WaymoTargetBuilder\nfrom smart.utils.config import load_config_act\nfrom smart.utils.log import Logging\n\nif __name__ == '__main__':\n    pl.seed_everything(2, workers=True)\n    parser = ArgumentParser()\n    parser.add_argument('--config', type=str, default=\"configs/validation/validation_scalable.yaml\")\n    parser.add_argument('--pretrain_ckpt', type=str, default=\"\")\n    parser.add_argument('--ckpt_path', type=str, default=\"\")\n    parser.add_argument('--save_ckpt_path', type=str, default=\"\")\n    args = parser.parse_args()\n    config = load_config_act(args.config)\n\n    data_config = config.Dataset\n    val_dataset = {\n        \"scalable\": MultiDataset,\n    }[data_config.dataset](root=data_config.root, split='val',\n                           raw_dir=data_config.val_raw_dir,\n                           processed_dir=data_config.val_processed_dir,\n                           transform=WaymoTargetBuilder(config.Model.num_historical_steps, config.Model.decoder.num_future_steps))\n    dataloader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers,\n                            pin_memory=data_config.pin_memory, persistent_workers=True if data_config.num_workers > 0 else False)\n    Predictor = SMART\n    if args.pretrain_ckpt == \"\":\n        model = Predictor(config.Model)\n    else:\n        logger = Logging().log(level='DEBUG')\n        model = Predictor(config.Model)\n        model.load_params_from_file(filename=args.pretrain_ckpt,\n                                    logger=logger)\n\n    trainer_config = config.Trainer\n    trainer = pl.Trainer(accelerator=trainer_config.accelerator,\n                         
devices=trainer_config.devices,\n                         strategy='ddp', num_sanity_val_steps=0)\n    trainer.validate(model, dataloader)\n"
  }
]