Repository: rainmaker22/SMART
Branch: main
Commit: 42e658542b03
Files: 52
Total size: 205.4 KB

Directory structure:
gitextract_98xwsw9y/

├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── configs/
│   ├── train/
│   │   └── train_scalable.yaml
│   └── validation/
│       └── validation_scalable.yaml
├── data_preprocess.py
├── environment.yml
├── pyproject.toml
├── requirements.txt
├── scripts/
│   ├── install_pyg.sh
│   └── traj_clstering.py
├── smart/
│   ├── __init__.py
│   ├── datamodules/
│   │   ├── __init__.py
│   │   └── scalable_datamodule.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── preprocess.py
│   │   └── scalable_dataset.py
│   ├── layers/
│   │   ├── __init__.py
│   │   ├── attention_layer.py
│   │   ├── fourier_embedding.py
│   │   └── mlp_layer.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── average_meter.py
│   │   ├── min_ade.py
│   │   ├── min_fde.py
│   │   ├── next_token_cls.py
│   │   └── utils.py
│   ├── model/
│   │   ├── __init__.py
│   │   └── smart.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── agent_decoder.py
│   │   ├── map_decoder.py
│   │   └── smart_decoder.py
│   ├── preprocess/
│   │   ├── __init__.py
│   │   └── preprocess.py
│   ├── tokens/
│   │   ├── __init__.py
│   │   ├── cluster_frame_5_2048.pkl
│   │   └── map_traj_token5.pkl
│   ├── transforms/
│   │   ├── __init__.py
│   │   └── target_builder.py
│   └── utils/
│       ├── __init__.py
│       ├── cluster_reader.py
│       ├── config.py
│       ├── geometry.py
│       ├── graph.py
│       ├── list.py
│       ├── log.py
│       ├── nan_checker.py
│       └── weight_init.py
├── train.py
└── val.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.github
ckpt/
# assets/
# C extensions
*.so
# /assets
/data
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
*.jpg
env/
venv/
ENV/
env.bak/
venv.bak/
*.jpg
pyg_depend/
# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# IDEs
.idea
.vscode

# seed project
av2/
lightning_logs/
lightning_logs_/
lightning_l/
.DS_Store
data/argo
data/res
data/waymo*
fig*/
data/waymo_token
data/submission
data/token_seq_emb_nuplan
data/token_seq_emb_waymo
data/nuplan*
submission.tar.gz
data/feat*
data/scalable
data/pos_data
res_metrics*
gathered*

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align="center">
  
  # SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction
  
  [Paper](https://arxiv.org/abs/2405.15677) | [Webpage](https://smart-motion.github.io/smart/)

</div>

- **Champion (ranked 1st)** of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)

## News
- **[December 31, 2024]** SMART-Planner achieved state-of-the-art performance on **nuPlan closed-loop planning**
- **[September 26, 2024]** SMART was **accepted to** NeurIPS 2024
- **[August 31, 2024]** Code released
- **[May 24, 2024]** SMART won the championship of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
- **[May 24, 2024]** SMART paper released on [arxiv](https://arxiv.org/abs/2405.15677)


## Introduction
This repository contains the official implementation of SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction. SMART is a novel motion-generation paradigm for autonomous driving that tokenizes vectorized map and agent-trajectory data into discrete sequence tokens.

https://github.com/user-attachments/assets/74a61627-8444-4e54-bb10-d317dd2aacd9

## Requirements

To set up the environment, you can use conda to create and activate a new environment with the necessary dependencies:

```bash
conda env create -f environment.yml
conda activate SMART
pip install -r requirements.txt
```

If you encounter issues while installing the PyG dependencies, run the provided script:
```bash
bash scripts/install_pyg.sh
```

Alternatively, you can configure the environment in your preferred way. Installing the latest versions of PyTorch, PyG, and PyTorch Lightning should suffice.

## Data installation

**Step 1: Download the Dataset**

Download the Waymo Open Motion Dataset (`scenario protocol` format) and organize the data as follows:
```
SMART
├── data
│   ├── waymo
│   │   ├── scenario
│   │   │   ├── training
│   │   │   ├── validation
│   │   │   ├── testing
├── model
├── tools
```

**Step 2: Install the Waymo Open Dataset API**

Follow the instructions [here](https://github.com/waymo-research/waymo-open-dataset) to install the Waymo Open Dataset API.

**Step 3: Preprocess the Dataset**

Preprocess the dataset by running:
```bash
python data_preprocess.py --input_dir ./data/waymo/scenario/training --output_dir ./data/waymo_processed/training
```
The first path is the raw data path, and the second is the output data path.

The processed data will be saved to the `data/waymo_processed/` directory as follows:

```
SMART
├── data
│   ├── waymo_processed
│   │   ├── training
│   │   ├── validation
│   │   ├── testing
├── model
├── utils
```
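
The preprocessing command above handles one split at a time. A small sketch that prints the corresponding command for each WOMD split (paths assumed from the layout above):

```shell
# Print the preprocessing command for each Waymo split (dry run)
for split in training validation testing; do
  echo python data_preprocess.py \
    --input_dir "./data/waymo/scenario/${split}" \
    --output_dir "./data/waymo_processed/${split}"
done
```

Drop the `echo` to actually run the commands for all three splits.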

## Training

To train the model, run the following command:

```bash
python train.py --config ${config_path}
```

The default config path is `configs/train/train_scalable.yaml`. Ensure you have downloaded and prepared the Waymo data for training.
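
`train_scalable.yaml` shares its `time_info` block across sections through a YAML anchor (`&time_info`) and merge keys (`<<: *time_info`). A minimal PyYAML sketch of how such a merge resolves, using a toy snippet that mirrors the pattern (not the repo's actual config loader; assumes PyYAML is installed):

```python
import yaml  # PyYAML (assumed available; the configs in this repo are YAML)

# Toy config using the same anchor/merge pattern as train_scalable.yaml
doc = """
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80

Model:
  hidden_dim: 128
  <<: *time_info
"""
cfg = yaml.safe_load(doc)
# The anchored keys are merged into Model alongside its own keys
print(cfg["Model"]["num_historical_steps"])  # 11
print(cfg["Model"]["hidden_dim"])            # 128
```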

## Evaluation

To evaluate the model, run:

```bash
python val.py --config ${config_path} --pretrain_ckpt ${ckpt_path}
```
This will evaluate the model using the configuration and checkpoint provided.


## Pre-trained Models

To comply with the WOMD participation agreement, we will release the model parameters of a medium-sized model not trained on Waymo data. Users can fine-tune this model with Waymo data as needed.

## Results

### Waymo Open Motion Dataset Sim Agents Challenge

Our model achieves the following performance on the [Waymo Open Motion Dataset Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/):

| Model name    | Metric Score |
| :-----------: | ------------ |
| SMART-tiny    | 0.7591       |
| SMART-large   | 0.7614       |
| SMART-zeroshot| 0.7210       |

### NuPlan Closed-loop Planning

**SMART-Planner** achieved state-of-the-art performance among learning-based algorithms on **nuPlan closed-loop planning**. The results on val14 are shown below:

![nuPlan Closed-loop Planning](assets/result1.png)

## Citation 

If you find this repository useful, please consider citing our work and giving us a star:

```bibtex
@article{wu2024smart,
  title={SMART: Scalable Multi-agent Real-time Simulation via Next-token Prediction},
  author={Wu, Wei and Feng, Xiaoxin and Gao, Ziyan and Kan, Yuheng},
  journal={arXiv preprint arXiv:2405.15677},
  year={2024}
}
```

## Acknowledgements
Special thanks to the [QCNET](https://github.com/ZikangZhou/QCNet) repository for providing valuable reference code that significantly influenced this work. 

## License
All code in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).


================================================
FILE: __init__.py
================================================



================================================
FILE: configs/train/train_scalable.yaml
================================================
# Shared time/token settings; merged into the sections below via YAML anchors (<<: *time_info)
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  use_intention: True
  token_size: 2048

Dataset:
  root:
  train_batch_size: 1
  val_batch_size: 1
  test_batch_size: 1
  shuffle: True
  num_workers: 1
  pin_memory: True
  persistent_workers: True
  train_raw_dir: ["data/valid_demo"]
  val_raw_dir: ["data/valid_demo"]
  test_raw_dir:
  transform: WaymoTargetBuilder
  train_processed_dir:
  val_processed_dir:
  test_processed_dir:
  dataset: "scalable"
  <<: *time_info

Trainer:
  strategy: ddp_find_unused_parameters_false
  accelerator: "gpu"
  devices: 1
  max_epochs: 32
  save_ckpt_path:
  num_nodes: 1
  mode:
  ckpt_path:
  precision: 32
  accumulate_grad_batches: 1

Model:
  mode: "train"
  predictor: "smart"
  dataset: "waymo"
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: False
  num_heads: 8
  <<: *time_info
  head_dim: 16
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32
  decoder:
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    time_span: 30


================================================
FILE: configs/validation/validation_scalable.yaml
================================================
# Shared time/token settings; merged into the sections below via YAML anchors (<<: *time_info)
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  token_size: 2048

Dataset:
  root:
  batch_size: 1
  shuffle: True
  num_workers: 1
  pin_memory: True
  persistent_workers: True
  train_raw_dir:
  val_raw_dir: ["data/valid_demo"]
  test_raw_dir:
  TargetBuilder: WaymoTargetBuilder
  train_processed_dir:
  val_processed_dir:
  test_processed_dir:
  dataset: "scalable"
  <<: *time_info

Trainer:
  strategy: ddp_find_unused_parameters_false
  accelerator: "gpu"
  devices: 1
  max_epochs: 32
  save_ckpt_path: 
  num_nodes: 1
  mode:
  ckpt_path: 
  precision: 32
  accumulate_grad_batches: 1

Model:
  mode: "validation"
  predictor: "smart"
  dataset: "waymo"
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: False
  num_heads: 8
  <<: *time_info
  head_dim: 16
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32
  decoder:
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    time_span: 30



================================================
FILE: data_preprocess.py
================================================
import numpy as np
import pandas as pd
import os
import torch
import pickle
from tqdm import tqdm
from typing import Any, Dict, List, Optional
import easydict

predict_unseen_agents = False
vector_repr = True
root = ''
split = 'train'
raw_dir = os.path.join(root, split, 'raw')
_raw_dir = raw_dir

if os.path.isdir(_raw_dir):
    _raw_file_names = [name for name in os.listdir(_raw_dir)]
else:
    _raw_file_names = []

processed_dir = os.path.join(root, split, 'processed')
_processed_dir = processed_dir
if os.path.isdir(_processed_dir):
    _processed_file_names = [name for name in os.listdir(_processed_dir) if
                             name.endswith(('pkl', 'pickle'))]
else:
    _processed_file_names = []

_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
_point_sides = ['LEFT', 'RIGHT', 'CENTER']
_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
_polygon_is_intersections = [True, False, None]


Lane_type_hash = {
    4: "BIKE",
    3: "VEHICLE",
    2: "VEHICLE",
    1: "BUS"
}

boundary_type_hash = {
        5: "UNKNOWN",
        6: "DASHED_WHITE",
        7: "SOLID_WHITE",
        8: "DOUBLE_DASH_WHITE",
        9: "DASHED_YELLOW",
        10: "DOUBLE_DASH_YELLOW",
        11: "SOLID_YELLOW",
        12: "DOUBLE_SOLID_YELLOW",
        13: "DASH_SOLID_YELLOW",
        14: "UNKNOWN",
        15: "EDGE",
        16: "EDGE"
}


def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    try:
        return ls.index(elem)
    except ValueError:
        return None
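

# Illustrative behavior (not part of the original file): unlike list.index,
# safe_list_index(['vehicle', 'pedestrian'], 'pedestrian') returns 1, while an
# element absent from the list returns None instead of raising ValueError,
# letting callers skip unknown ids.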


def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
    if not predict_unseen_agents:  # filter out agents that are unseen during the historical time steps
        historical_df = df[df['timestep'] == num_historical_steps-1]
        agent_ids = list(historical_df['track_id'].unique())
        df = df[df['track_id'].isin(agent_ids)]
    else:
        agent_ids = list(df['track_id'].unique())

    num_agents = len(agent_ids)
    # initialization
    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    agent_id: List[Optional[str]] = [None] * num_agents
    agent_type = torch.zeros(num_agents, dtype=torch.uint8)
    agent_category = torch.zeros(num_agents, dtype=torch.uint8)
    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)

    for track_id, track_df in df.groupby('track_id'):
        agent_idx = agent_ids.index(track_id)
        agent_steps = track_df['timestep'].values

        valid_mask[agent_idx, agent_steps] = True
        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
        predict_mask[agent_idx, agent_steps] = True
        if vector_repr:  # a time step t is valid only when both t and t-1 are valid
            valid_mask[agent_idx, 1: num_historical_steps] = (
                valid_mask[agent_idx, :num_historical_steps - 1] &
                valid_mask[agent_idx, 1: num_historical_steps])
            valid_mask[agent_idx, 0] = False
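            # e.g. over the first four steps, valid_mask [1, 0, 1, 1] becomes
            # [0, 0, 0, 1]: step t survives only if step t-1 was also observed,
            # since the vector representation needs a (t-1, t) displacement.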
        predict_mask[agent_idx, :num_historical_steps] = False
        if not current_valid_mask[agent_idx]:
            predict_mask[agent_idx, num_historical_steps:] = False

        agent_id[agent_idx] = track_id
        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
        agent_category[agent_idx] = track_df['object_category'].values[0]
        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
                                                                          track_df['position_y'].values,
                                                                          track_df['position_z'].values],
                                                                         axis=-1)).float()
        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
                                                                          track_df['velocity_y'].values],
                                                                         axis=-1)).float()
        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
                                                                       track_df['width'].values,
                                                                       track_df["height"].values],
                                                                      axis=-1)).float()
    av_idx = agent_id.index(av_id)
    if split == 'test':
        predict_mask[current_valid_mask
                     | (agent_category == 2)
                     | (agent_category == 3), num_historical_steps:] = True

    return {
        'num_nodes': num_agents,
        'av_index': av_idx,
        'valid_mask': valid_mask,
        'predict_mask': predict_mask,
        'id': agent_id,
        'type': agent_type,
        'category': agent_category,
        'position': position,
        'heading': heading,
        'velocity': velocity,
        'shape': shape
    }


def get_map_features(map_infos, tf_current_light, dim=3):
    lane_segments = map_infos['lane']
    all_polylines = map_infos["all_polylines"]
    crosswalks = map_infos['crosswalk']
    road_edges = map_infos['road_edge']
    road_lines = map_infos['road_line']
    lane_segment_ids = [info["id"] for info in lane_segments]
    cross_walk_ids = [info["id"] for info in crosswalks]
    road_edge_ids = [info["id"] for info in road_edges]
    road_line_ids = [info["id"] for info in road_lines]
    polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids
    num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids)

    # initialization
    polygon_type = torch.zeros(num_polygons, dtype=torch.uint8)
    polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3

    point_position: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_height: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_type: List[Optional[torch.Tensor]] = [None] * num_polygons

    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type])

        res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)]
        if len(res) != 0:
            polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item())

        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index('CENTERLINE')
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    for lane_segment in road_edges:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")

        point_position[lane_segment_idx] = centerline[:-1, :dim]
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.atan2(center_vectors[:, 1], center_vectors[:, 0])
        point_magnitude[lane_segment_idx] = torch.norm(center_vectors[:, :2], p=2, dim=-1)
        point_height[lane_segment_idx] = center_vectors[:, 2]
        center_type = _point_types.index('EDGE')
        point_type[lane_segment_idx] = torch.full((len(center_vectors),), center_type, dtype=torch.uint8)

    for lane_segment in road_lines:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()

        polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")

        point_position[lane_segment_idx] = centerline[:-1, :dim]
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.atan2(center_vectors[:, 1], center_vectors[:, 0])
        point_magnitude[lane_segment_idx] = torch.norm(center_vectors[:, :2], p=2, dim=-1)
        point_height[lane_segment_idx] = center_vectors[:, 2]
        center_type = _point_types.index(boundary_type_hash[lane_segment.type])
        point_type[lane_segment_idx] = torch.full((len(center_vectors),), center_type, dtype=torch.uint8)

    for crosswalk in crosswalks:
        crosswalk = easydict.EasyDict(crosswalk)
        lane_segment_idx = polygon_ids.index(crosswalk.id)
        polyline_index = crosswalk.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()

        polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN")

        point_position[lane_segment_idx] = centerline[:-1, :dim]
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.atan2(center_vectors[:, 1], center_vectors[:, 0])
        point_magnitude[lane_segment_idx] = torch.norm(center_vectors[:, :2], p=2, dim=-1)
        point_height[lane_segment_idx] = center_vectors[:, 2]
        center_type = _point_types.index("CROSSWALK")
        point_type[lane_segment_idx] = torch.full((len(center_vectors),), center_type, dtype=torch.uint8)

    num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long)
    point_to_polygon_edge_index = torch.stack(
        [torch.arange(num_points.sum(), dtype=torch.long),
            torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0)
    polygon_to_polygon_edge_index = []
    polygon_to_polygon_type = []
    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        pred_inds = []
        for pred in lane_segment.entry_lanes:
            pred_idx = safe_list_index(polygon_ids, pred)
            if pred_idx is not None:
                pred_inds.append(pred_idx)
        if len(pred_inds) != 0:
            polygon_to_polygon_edge_index.append(
                torch.stack([torch.tensor(pred_inds, dtype=torch.long),
                             torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
            polygon_to_polygon_type.append(
                torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8))
        succ_inds = []
        for succ in lane_segment.exit_lanes:
            succ_idx = safe_list_index(polygon_ids, succ)
            if succ_idx is not None:
                succ_inds.append(succ_idx)
        if len(succ_inds) != 0:
            polygon_to_polygon_edge_index.append(
                torch.stack([torch.tensor(succ_inds, dtype=torch.long),
                             torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
            polygon_to_polygon_type.append(
                torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8))
        if len(lane_segment.left_neighbors) != 0:
            left_neighbor_ids = lane_segment.left_neighbors
            for left_neighbor_id in left_neighbor_ids:
                left_idx = safe_list_index(polygon_ids, left_neighbor_id)
                if left_idx is not None:
                    polygon_to_polygon_edge_index.append(
                        torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long))
                    polygon_to_polygon_type.append(
                        torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8))
        if len(lane_segment.right_neighbors) != 0:
            right_neighbor_ids = lane_segment.right_neighbors
            for right_neighbor_id in right_neighbor_ids:
                right_idx = safe_list_index(polygon_ids, right_neighbor_id)
                if right_idx is not None:
                    polygon_to_polygon_edge_index.append(
                        torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long))
                    polygon_to_polygon_type.append(
                        torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8))
    if len(polygon_to_polygon_edge_index) != 0:
        polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1)
        polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0)
    else:
        polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long)
        polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8)

    map_data = {
        'map_polygon': {},
        'map_point': {},
        ('map_point', 'to', 'map_polygon'): {},
        ('map_polygon', 'to', 'map_polygon'): {},
    }
    map_data['map_polygon']['num_nodes'] = num_polygons
    map_data['map_polygon']['type'] = polygon_type
    map_data['map_polygon']['light_type'] = polygon_light_type
    if len(num_points) == 0:
        map_data['map_point']['num_nodes'] = 0
        map_data['map_point']['position'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float)
        if dim == 3:
            map_data['map_point']['height'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8)
        map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8)
    else:
        map_data['map_point']['num_nodes'] = num_points.sum().item()
        map_data['map_point']['position'] = torch.cat(point_position, dim=0)
        map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0)
        map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0)
        if dim == 3:
            map_data['map_point']['height'] = torch.cat(point_height, dim=0)
        map_data['map_point']['type'] = torch.cat(point_type, dim=0)
    map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index
    map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index
    map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type
    # import matplotlib.pyplot as plt
    # plt.axis('equal')
    # plt.scatter(map_data['map_point']['position'][:, 0],
    #             map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')
    # plt.show(dpi=600)
    return map_data
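The point-to-polygon edge index built above pairs every map point with its owning polygon: row 0 enumerates points globally, row 1 repeats each polygon index once per owned point. A minimal NumPy sketch of the same construction (`np.repeat` standing in for `torch.repeat_interleave`), with hypothetical point counts:

```python
import numpy as np

# Hypothetical counts: 3 polygons owning 2, 1, and 3 points respectively.
num_points = np.array([2, 1, 3])
num_polygons = len(num_points)

# Row 0: a global index for every point; row 1: the owning polygon of each point.
edge_index = np.stack([
    np.arange(num_points.sum()),
    np.repeat(np.arange(num_polygons), num_points),
])

print(edge_index.tolist())  # [[0, 1, 2, 3, 4, 5], [0, 0, 1, 2, 2, 2]]
```

This is the standard COO layout that PyTorch Geometric expects for heterogeneous `edge_index` tensors.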


def process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, start_timestamp, end_timestamp):
    agents_array = track_info["trajs"].transpose(1, 0, 2)
    object_id = np.array(track_info["object_id"])
    object_type = track_info["object_type"]
    id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))}
    def type_hash(x):
        tp = id_hash[x]
        type_re_hash = {
            "TYPE_VEHICLE": "vehicle",
            "TYPE_PEDESTRIAN": "pedestrian",
            "TYPE_CYCLIST": "cyclist",
            "TYPE_OTHER": "background",
            "TYPE_UNSET": "background"
        }
        return type_re_hash[tp]

    columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',
               'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y',
               'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps',
               'focal_track_id', 'city']
    new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11))
    new_columns[:11, :, 0] = True   # first 11 timesteps form the observed history
    new_columns[11:, :, 0] = False
    for index in range(new_columns.shape[0]):
        new_columns[index, :, 4] = int(index)  # timestep
    new_columns[..., 1] = object_id  # track_id
    new_columns[..., 2] = object_id  # object_type, re-hashed to a string below
    new_columns[:, tracks_to_predict["track_index"], 3] = 3  # object_category: focal tracks
    new_columns[..., 5] = 11         # scenario_id placeholder, overwritten below
    new_columns[..., 6] = int(start_timestamp)
    new_columns[..., 7] = int(end_timestamp)
    new_columns[..., 8] = int(91)    # num_timestamps
    new_columns[..., 9] = object_id  # focal_track_id
    new_columns[..., 10] = 10086     # dummy city code
    new_agents_array = np.concatenate([new_columns, agents_array], axis=-1)
    new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1])
    new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10]]
    new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns)
    new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash)
    new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int)
    new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int)
    new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int)
    new_agents_array["scenario_id"] = scenario_id
    return new_agents_array
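`process_agent` flattens the (timestep, agent) grid into per-row records and drops rows whose trailing `valid` flag is 0, using boolean indexing. A tiny sketch of that filtering step with a hypothetical array:

```python
import numpy as np

# Hypothetical (num_steps=2, num_agents=2, features=3) array whose last
# feature plays the role of the Waymo `valid` flag.
arr = np.array([[[1.0, 2.0, 1.0], [3.0, 4.0, 0.0]],
                [[5.0, 6.0, 1.0], [7.0, 8.0, 1.0]]])

# Boolean indexing with a (steps, agents) mask flattens the leading dims,
# keeping only the valid (timestep, agent) rows.
rows = arr[arr[..., -1] == 1.0].reshape(-1, arr.shape[-1])
print(rows.tolist())
# [[1.0, 2.0, 1.0], [5.0, 6.0, 1.0], [7.0, 8.0, 1.0]]
```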


def process_dynamic_map(dynamic_map_infos):
    lane_ids = dynamic_map_infos["lane_id"]
    tf_lights = []
    for t in range(len(lane_ids)):
        lane_id = lane_ids[t]
        time = np.ones_like(lane_id) * t
        state = dynamic_map_infos["state"][t]
        tf_light = np.concatenate([lane_id, time, state], axis=0)
        tf_lights.append(tf_light)
    tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0)
    tf_lights = pd.DataFrame(data=tf_lights, columns=["lane_id", "time_step", "state"])
    tf_lights["time_step"] = tf_lights["time_step"].astype("str")
    tf_lights["lane_id"] = tf_lights["lane_id"].astype("str")
    tf_lights["state"] = tf_lights["state"].astype("str")
    tf_lights.loc[tf_lights["state"].str.contains("STOP"), "state"] = 'LANE_STATE_STOP'
    tf_lights.loc[tf_lights["state"].str.contains("GO"), "state"] = 'LANE_STATE_GO'
    tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), "state"] = 'LANE_STATE_CAUTION'
    return tf_lights
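`process_dynamic_map` collapses the arrow and flashing signal variants onto the three round-signal states via substring matching. A dependency-free sketch of that normalization (`canonical_state` is an illustrative helper, not part of the codebase):

```python
def canonical_state(state: str) -> str:
    # Collapse arrow/flashing variants onto the three round states,
    # mirroring the str.contains() rewrites above.
    for key, canonical in (("STOP", "LANE_STATE_STOP"),
                           ("GO", "LANE_STATE_GO"),
                           ("CAUTION", "LANE_STATE_CAUTION")):
        if key in state:
            return canonical
    return state

states = ["LANE_STATE_ARROW_STOP", "LANE_STATE_FLASHING_CAUTION", "LANE_STATE_GO"]
print([canonical_state(s) for s in states])
# ['LANE_STATE_STOP', 'LANE_STATE_CAUTION', 'LANE_STATE_GO']
```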


polyline_type = {
    # for lane
    'TYPE_UNDEFINED': -1,
    'TYPE_FREEWAY': 1,
    'TYPE_SURFACE_STREET': 2,
    'TYPE_BIKE_LANE': 3,

    # for roadline
    'TYPE_UNKNOWN': -1,
    'TYPE_BROKEN_SINGLE_WHITE': 6,
    'TYPE_SOLID_SINGLE_WHITE': 7,
    'TYPE_SOLID_DOUBLE_WHITE': 8,
    'TYPE_BROKEN_SINGLE_YELLOW': 9,
    'TYPE_BROKEN_DOUBLE_YELLOW': 10,
    'TYPE_SOLID_SINGLE_YELLOW': 11,
    'TYPE_SOLID_DOUBLE_YELLOW': 12,
    'TYPE_PASSING_DOUBLE_YELLOW': 13,

    # for roadedge
    'TYPE_ROAD_EDGE_BOUNDARY': 15,
    'TYPE_ROAD_EDGE_MEDIAN': 16,

    # for stopsign
    'TYPE_STOP_SIGN': 17,

    # for crosswalk
    'TYPE_CROSSWALK': 18,

    # for speed bump
    'TYPE_SPEED_BUMP': 19
}

object_type = {
    0: 'TYPE_UNSET',
    1: 'TYPE_VEHICLE',
    2: 'TYPE_PEDESTRIAN',
    3: 'TYPE_CYCLIST',
    4: 'TYPE_OTHER'
}


signal_state = {
    0: 'LANE_STATE_UNKNOWN',

    # // States for traffic signals with arrows.
    1: 'LANE_STATE_ARROW_STOP',
    2: 'LANE_STATE_ARROW_CAUTION',
    3: 'LANE_STATE_ARROW_GO',

    # // Standard round traffic signals.
    4: 'LANE_STATE_STOP',
    5: 'LANE_STATE_CAUTION',
    6: 'LANE_STATE_GO',

    # // Flashing light signals.
    7: 'LANE_STATE_FLASHING_STOP',
    8: 'LANE_STATE_FLASHING_CAUTION'
}

signal_state_to_id = {name: idx for idx, name in signal_state.items()}


def decode_tracks_from_proto(tracks):
    track_infos = {
        'object_id': [],
        'object_type': [],  # {0: unset, 1: vehicle, 2: pedestrian, 3: cyclist, 4: other}
        'trajs': []
    }
    for cur_data in tracks:  # number of objects
        cur_traj = [np.array([x.center_x, x.center_y, x.center_z, x.length, x.width, x.height, x.heading,
                              x.velocity_x, x.velocity_y, x.valid], dtype=np.float32) for x in cur_data.states]
        cur_traj = np.stack(cur_traj, axis=0)  # (num_timestamp, 10)

        track_infos['object_id'].append(cur_data.id)
        track_infos['object_type'].append(object_type[cur_data.object_type])
        track_infos['trajs'].append(cur_traj)

    track_infos['trajs'] = np.stack(track_infos['trajs'], axis=0)  # (num_objects, num_timestamp, 10)
    return track_infos


from collections import defaultdict


def decode_map_features_from_proto(map_features):
    map_infos = {
        'lane': [],
        'road_line': [],
        'road_edge': [],
        'stop_sign': [],
        'crosswalk': [],
        'speed_bump': [],
        'lane_dict': {},
        'lane2other_dict': {}
    }
    polylines = []

    point_cnt = 0
    lane2other_dict = defaultdict(list)

    for cur_data in map_features:
        cur_info = {'id': cur_data.id}

        if cur_data.lane.ByteSize() > 0:
            cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph
            cur_info['type'] = cur_data.lane.type + 1  # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane
            cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]

            cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]

            cur_info['interpolating'] = cur_data.lane.interpolating
            cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)
            cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)

            cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]

            cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
            cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
            cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
            cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]

            lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
            lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])

            global_type = cur_info['type']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],
                axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['lane'].append(cur_info)
            map_infos['lane_dict'][cur_data.id] = cur_info

        elif cur_data.road_line.ByteSize() > 0:
            cur_info['type'] = cur_data.road_line.type + 5

            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_line.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_line'].append(cur_info)

        elif cur_data.road_edge.ByteSize() > 0:
            cur_info['type'] = cur_data.road_edge.type + 14

            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_edge.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_edge'].append(cur_info)

        elif cur_data.stop_sign.ByteSize() > 0:
            cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
            for i in cur_info['lane_ids']:
                lane2other_dict[i].append(cur_data.id)
            point = cur_data.stop_sign.position
            cur_info['position'] = np.array([point.x, point.y, point.z])

            global_type = polyline_type['TYPE_STOP_SIGN']
            cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
            # a stop sign is a single point; the `shape[0] <= 1` guard used for
            # polylines would always skip this branch, so it is omitted here
            map_infos['stop_sign'].append(cur_info)
        elif cur_data.crosswalk.ByteSize() > 0:
            global_type = polyline_type['TYPE_CROSSWALK']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.crosswalk.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['crosswalk'].append(cur_info)

        elif cur_data.speed_bump.ByteSize() > 0:
            global_type = polyline_type['TYPE_SPEED_BUMP']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.speed_bump.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['speed_bump'].append(cur_info)

        else:
            # print(cur_data)
            continue
        polylines.append(cur_polyline)
        cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))
        point_cnt += len(cur_polyline)

    if len(polylines) > 0:
        polylines = np.concatenate(polylines, axis=0).astype(np.float32)
    else:
        # no valid map features: keep an empty (0, 5) buffer of [x, y, z, type, id]
        polylines = np.zeros((0, 5), dtype=np.float32)
    map_infos['all_polylines'] = polylines
    map_infos['lane2other_dict'] = lane2other_dict
    return map_infos
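`decode_map_features_from_proto` packs every feature's points into one `all_polylines` array and stores a half-open `(start, end)` slice per feature in `polyline_index`; `get_map_features` later recovers each feature by slicing. A minimal sketch of that bookkeeping with made-up polylines:

```python
import numpy as np

# Hypothetical per-feature polylines (rows of [x, y, z, type, id]).
feature_polylines = [np.ones((4, 5)), np.ones((2, 5)) * 2, np.ones((3, 5)) * 3]

point_cnt = 0
polyline_index = []
for poly in feature_polylines:
    # Each feature records its half-open (start, end) slice, as above.
    polyline_index.append((point_cnt, point_cnt + len(poly)))
    point_cnt += len(poly)

all_polylines = np.concatenate(feature_polylines, axis=0)
start, end = polyline_index[1]
assert np.array_equal(all_polylines[start:end], feature_polylines[1])
print(polyline_index)  # [(0, 4), (4, 6), (6, 9)]
```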


def decode_dynamic_map_states_from_proto(dynamic_map_states):
    dynamic_map_infos = {
        'lane_id': [],
        'state': [],
        'stop_point': []
    }
    for cur_data in dynamic_map_states:  # (num_timestamp)
        lane_id, state, stop_point = [], [], []
        for cur_signal in cur_data.lane_states:  # (num_observed_signals)
            lane_id.append(cur_signal.lane)
            state.append(signal_state[cur_signal.state])
            stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, cur_signal.stop_point.z])

        dynamic_map_infos['lane_id'].append(np.array([lane_id]))
        dynamic_map_infos['state'].append(np.array([state]))
        dynamic_map_infos['stop_point'].append(np.array([stop_point]))

    return dynamic_map_infos


def process_single_data(scenario):
    info = {}
    info['scenario_id'] = scenario.scenario_id
    info['timestamps_seconds'] = list(scenario.timestamps_seconds)  # list of int of shape (91)
    info['current_time_index'] = scenario.current_time_index  # int, 10
    info['sdc_track_index'] = scenario.sdc_track_index  # int
    info['objects_of_interest'] = list(scenario.objects_of_interest)  # list, could be empty list

    info['tracks_to_predict'] = {
        'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
        'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
    }  # for training: suggestion of objects to train on, for val/test: need to be predicted

    track_infos = decode_tracks_from_proto(scenario.tracks)
    info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in
                                                info['tracks_to_predict']['track_index']]

    # decode map related data
    map_infos = decode_map_features_from_proto(scenario.map_features)
    dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)

    save_infos = {
        'track_infos': track_infos,
        'dynamic_map_infos': dynamic_map_infos,
        'map_infos': map_infos
    }
    save_infos.update(info)
    return save_infos

import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2


def wm2argo(file, dir_name, output_dir):
    file_path = os.path.join(dir_name, file)
    dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)
    for cnt, data in enumerate(dataset):
        print(cnt)
        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytearray(data.numpy()))
        save_infos = process_single_data(scenario) # pkl2mtr
        map_info = save_infos["map_infos"]
        track_info = save_infos['track_infos']
        scenario_id = save_infos['scenario_id']
        tracks_to_predict = save_infos['tracks_to_predict']
        sdc_track_index = save_infos['sdc_track_index']
        av_id = track_info["object_id"][sdc_track_index]
        if len(tracks_to_predict["track_index"]) < 1:
            continue  # skip this scenario; `return` would drop the rest of the file
        dynamic_map_infos = save_infos["dynamic_map_infos"]
        tf_lights = process_dynamic_map(dynamic_map_infos)
        tf_current_light = tf_lights.loc[tf_lights["time_step"] == "11"]
        map_data = get_map_features(map_info, tf_current_light)
        new_agents_array = process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, 0, 91) # mtr2argo
        data = dict()
        data['scenario_id'] = new_agents_array['scenario_id'].values[0]
        data['city'] = new_agents_array['city'].values[0]
        data['agent'] = get_agent_features(new_agents_array, av_id, num_historical_steps=11)
        data.update(map_data)
        with open(os.path.join(output_dir, scenario_id + '.pkl'), "wb+") as f:
            pickle.dump(data, f)


def batch_process9s_transformer(dir_name, output_dir, num_workers=2):
    from functools import partial
    import multiprocessing
    packages = os.listdir(dir_name)
    func = partial(
        wm2argo, output_dir=output_dir, dir_name=dir_name)
    with multiprocessing.Pool(num_workers) as p:
        list(tqdm(p.imap(func, packages), total=len(packages)))


from argparse import ArgumentParser


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='data/waymo/scenario/training')
    parser.add_argument('--output_dir', type=str, default='data/waymo_processed/training')
    args = parser.parse_args()
    files = os.listdir(args.input_dir)
    for file in tqdm(files):
        wm2argo(file, args.input_dir, args.output_dir)
    # batch_process9s_transformer(args.input_dir, args.output_dir, num_workers=os.cpu_count())


================================================
FILE: environment.yml
================================================
name: smart
channels:
  - pytorch
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py39h6a678d5_8
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.9.24=h06a4308_0
  - certifi=2024.8.30=py39h06a4308_0
  - charset-normalizer=3.3.2=pyhd3eb1b0_0
  - cudatoolkit=11.3.1=h2bc3f7f_2
  - ffmpeg=4.3=hf484d3e_0
  - freetype=2.12.1=h4a9f257_0
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - idna=3.7=py39h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jpeg=9e=h5eee18b_3
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - lerc=3.0=h295c915_0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.14=0
  - libidn2=2.3.4=h5eee18b_0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libwebp-base=1.3.2=h5eee18b_1
  - lz4-c=1.9.4=h6a678d5_1
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py39h5eee18b_1
  - mkl_fft=1.3.10=py39h5eee18b_0
  - mkl_random=1.2.7=py39h1128e8f_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.5.2=he7f1fd0_0
  - openssl=3.0.15=h5eee18b_0
  - pillow=10.4.0=py39h5eee18b_0
  - pip=24.2=py39h06a4308_0
  - pysocks=1.7.1=py39h06a4308_0
  - python=3.9.19=h955ad1f_1
  - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
  - pytorch-mutex=1.0=cuda
  - readline=8.2=h5eee18b_0
  - requests=2.32.3=py39h06a4308_0
  - setuptools=75.1.0=py39h06a4308_0
  - sqlite=3.45.3=h5eee18b_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.14=h39e8969_0
  - torchvision=0.13.1=py39_cu113
  - typing_extensions=4.11.0=py39h06a4308_0
  - urllib3=2.2.3=py39h06a4308_0
  - wheel=0.44.0=py39h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zlib=1.2.13=h5eee18b_1
  - zstd=1.5.6=hc292b87_0


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "smart"
version = "0.0.0"
description = "Scalable Multi-agent Real-time Motion Generation via Next-token Prediction"
readme = "README.md"
authors = [
    {name = "Xiaoxin Feng"},
    {name = "Ziyan Gao"},
    {name = "Yuheng Kan"}
]
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
]
requires-python = ">=3.9"
dependencies = [
    "easydict",
    "numpy",
    "pandas",
    "pytorch-lightning",
    "scipy",
    "torch-cluster",
    "torch-geometric",
    "torch-scatter",
    "torch",
    "torchmetrics",
    "tqdm",
]

[project.urls]
"Homepage" = "https://smart-motion.github.io/smart/"
"Repository" = "https://github.com/rainmaker22/SMART"
"Paper" = "https://arxiv.org/abs/2405.15677"

[tool.setuptools]
packages = ["smart"]


================================================
FILE: requirements.txt
================================================
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
async-timeout==4.0.3
attrs==24.2.0
contourpy==1.3.0
cycler==0.12.1
easydict==1.13
fonttools==4.54.1
frozenlist==1.4.1
fsspec==2024.10.0
importlib-resources==6.4.5
jinja2==3.1.4
kiwisolver==1.4.7
lightning-utilities==0.11.8
markupsafe==3.0.2
matplotlib==3.9.2
multidict==6.1.0
numpy==1.26.4
packaging==24.1
pandas==2.0.3
propcache==0.2.0
psutil==6.1.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
pytorch-lightning==2.0.3
pytz==2024.2
pyyaml==6.0.1
scipy==1.10.1
shapely==2.0.6
six==1.16.0
torch-cluster==1.6.0+pt112cu113
torch-geometric==2.6.1
torch-scatter==2.1.0+pt112cu113
torch-sparse==0.6.16+pt112cu113
torch-spline-conv==1.2.1+pt112cu113
torchmetrics==1.5.0
tqdm==4.66.5
tzdata==2024.2
yarl==1.16.0
zipp==3.20.2
waymo-open-dataset-tf-2-12-0==1.6.4


================================================
FILE: scripts/install_pyg.sh
================================================
#!/bin/bash
set -e
mkdir pyg_depend && cd pyg_depend
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_geometric


================================================
FILE: scripts/traj_clstering.py
================================================
from smart.utils.geometry import wrap_angle
import numpy as np


def average_distance_vectorized(point_set1, centroids):
    dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :])**2, axis=-1))
    return np.mean(dists, axis=2)
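`average_distance_vectorized` compares every contour against every centroid in one broadcasted operation: an `(A, 1, P, 2)` minus `(1, C, P, 2)` difference yields per-corner distances that are then averaged over the `P` corners. A small self-checking sketch with assumed shapes:

```python
import numpy as np

rng = np.random.default_rng(0)
point_set1 = rng.random((5, 4, 2))   # 5 contours, 4 corners each, (x, y)
centroids = rng.random((3, 4, 2))    # 3 cluster centroids

# Broadcast to (5, 3, 4, 2), take per-corner distances, average over corners.
dists = np.sqrt(np.sum((point_set1[:, None] - centroids[None])**2, axis=-1))
mean_dists = np.mean(dists, axis=2)

assert mean_dists.shape == (5, 3)
# Brute-force check of one entry against the broadcasted result:
manual = np.mean(np.linalg.norm(point_set1[2] - centroids[1], axis=-1))
assert np.isclose(mean_dists[2, 1], manual)
```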


def assign_clusters(sub_X, centroids):
    distances = average_distance_vectorized(sub_X, centroids)
    return np.argmin(distances, axis=1)


def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
    # Note: relies on the module-level flag `cal_mean_heading` and on
    # cal_polygon_contour() defined below; X and a_pos shrink as centers are picked.
    S = []
    ret_traj_list = []
    while len(S) < N:
        num_all = X.shape[0]
        # randomly select a candidate cluster center
        choice_index = np.random.choice(num_all)
        x0 = X[choice_index]
        if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10:
            continue
        res_mask = np.sum((X - x0)**2, axis=(1, 2))/4 > (tol**2)
        del_mask = np.sum((X - x0)**2, axis=(1, 2))/4 <= (tol**2)
        if cal_mean_heading:
            del_contour = X[del_mask]
            diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :]
            del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean()
            x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length)
            del_traj = a_pos[del_mask]
            ret_traj = del_traj.mean(0)[None, ...]
            if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0:
                # debug: flag averaged trajectories with a suspicious backward jump
                print(ret_traj)
        else:
            x0 = x0[None, ...]
            ret_traj = a_pos[choice_index][None, ...]
        X = X[res_mask]
        a_pos = a_pos[res_mask]
        S.append(x0)
        ret_traj_list.append(ret_traj)
    centroids = np.concatenate(S, axis=0)
    ret_traj = np.concatenate(ret_traj_list, axis=0)

    # closest_dist_sq = np.sum((X - centroids[0])**2, axis=(1, 2))

    # for k in range(1, K):
    #     new_dist_sq = np.sum((X - centroids[k - 1])**2, axis=(1, 2))
    #     closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq)
    #     probabilities = closest_dist_sq / np.sum(closest_dist_sq)
    #     centroids[k] = X[np.random.choice(N, p=probabilities)]

    return centroids, ret_traj
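`Kdisk_cluster` is a greedy disk-covering selection: sample a point, delete everything within `tol` of it, and keep sampling until `N` centers survive, which guarantees the kept centers are mutually at least `tol` apart (here `tol` is a corner-averaged contour distance). A minimal 1-D sketch of that criterion, simplified (no heading averaging, hypothetical data):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random(200)   # hypothetical 1-D samples in [0, 1)
tol, N = 0.1, 5

centers = []
while len(centers) < N and len(X) > 0:
    x0 = X[rng.integers(len(X))]
    X = X[np.abs(X - x0) > tol]   # drop the whole disk around the new center
    centers.append(x0)

# Every pair of kept centers is at least tol apart by construction.
assert all(abs(a - b) > tol for i, a in enumerate(centers) for b in centers[i + 1:])
```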


def cal_polygon_contour(x, y, theta, width, length):

    left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_front = np.column_stack((left_front_x, left_front_y))

    right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_front = np.column_stack((right_front_x, right_front_y))

    right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_back = np.column_stack((right_back_x, right_back_y))

    left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_back = np.column_stack((left_back_x, left_back_y))

    polygon_contour = np.concatenate((left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)

    return polygon_contour
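`cal_polygon_contour` builds the four corners of an oriented box by rotating the (±length/2, ±width/2) offsets by `theta` and translating to (x, y). A condensed, self-contained check of that corner layout (hypothetical helper name):

```python
import numpy as np

def box_corners(x, y, theta, width, length):
    # rotate the four (±length/2, ±width/2) offsets by theta, then translate;
    # corner order matches cal_polygon_contour: LF, RF, RB, LB
    dx = 0.5 * length * np.array([1.0, 1.0, -1.0, -1.0])
    dy = 0.5 * width * np.array([1.0, -1.0, -1.0, 1.0])
    cx = x + dx * np.cos(theta) - dy * np.sin(theta)
    cy = y + dx * np.sin(theta) + dy * np.cos(theta)
    return np.stack([cx, cy], axis=-1)  # (4, 2)
```

At theta = 0 this reduces to an axis-aligned box: a 4.8 m by 2.0 m vehicle centered at the origin has its left-front corner at (2.4, 1.0).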


if __name__ == '__main__':
    shift = 5 # motion token time dimension
    num_cluster = 6 # vocabulary size
    cal_mean_heading = True
    data = {
        "veh": np.random.rand(1000, 6, 3),
        "cyc": np.random.rand(1000, 6, 3),
        "ped": np.random.rand(1000, 6, 3)
    }
    # Collect the trajectories of all traffic participants from the raw data [NumAgent, shift+1, [relative_x, relative_y, relative_theta]]
    nms_res = {}
    res = {'token': {}, 'traj': {}, 'token_all': {}}
    for k, v in data.items():
        # if k != 'veh':
        #     continue
        a_pos = v
        print(a_pos.shape)
        # a_pos = a_pos[:, shift:1+shift, :]
        cal_num = min(int(1e6), a_pos.shape[0])
        a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)]
        a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1])
        print(a_pos.shape)
        if shift <= 2:
            if k == 'veh':
                width = 1.0
                length = 2.4
            elif k == 'cyc':
                width = 0.5
                length = 1.5
            else:
                width = 0.5
                length = 0.5
        else:
            if k == 'veh':
                width = 2.0
                length = 4.8
            elif k == 'cyc':
                width = 1.0
                length = 2.0
            else:
                width = 1.0
                length = 1.0
        contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1], a_pos[:, shift, 2], width, length)

        # plt.figure(figsize=(10, 10))
        # for rect in contour:
        #     rect_closed = np.vstack([rect, rect[0]])
        #     plt.plot(rect_closed[:, 0], rect_closed[:, 1], linewidth=0.1)

        # plt.title("Plot of 256 Rectangles")
        # plt.xlabel("x")
        # plt.ylabel("y")
        # plt.axis('equal')
        # plt.savefig(f'src_{k}_new.jpg', dpi=300)

        if k == 'veh':
            tol = 0.05
        elif k == 'cyc':
            tol = 0.004
        else:
            tol = 0.004
        centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length, a_pos[:, :shift+1])
        # plt.figure(figsize=(10, 10))
        contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(num_cluster*(shift+1)),
                                      ret_traj[:, :, 1].reshape(num_cluster*(shift+1)),
                                      ret_traj[:, :, 2].reshape(num_cluster*(shift+1)), width, length)

        res['token_all'][k] = contour.reshape(num_cluster, (shift+1), 4, 2)
        res['token'][k] = centroids
        res['traj'][k] = ret_traj


================================================
FILE: smart/__init__.py
================================================


================================================
FILE: smart/datamodules/__init__.py
================================================
from smart.datamodules.scalable_datamodule import MultiDataModule


================================================
FILE: smart/datamodules/scalable_datamodule.py
================================================
from typing import Optional

import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
from smart.datasets.scalable_dataset import MultiDataset
from smart.transforms import WaymoTargetBuilder


class MultiDataModule(pl.LightningDataModule):
    transforms = {
        "WaymoTargetBuilder": WaymoTargetBuilder,
    }

    dataset = {
        "scalable": MultiDataset,
    }

    def __init__(self,
                 root: str,
                 train_batch_size: int,
                 val_batch_size: int,
                 test_batch_size: int,
                 shuffle: bool = False,
                 num_workers: int = 0,
                 pin_memory: bool = True,
                 persistent_workers: bool = True,
                 train_raw_dir: Optional[str] = None,
                 val_raw_dir: Optional[str] = None,
                 test_raw_dir: Optional[str] = None,
                 train_processed_dir: Optional[str] = None,
                 val_processed_dir: Optional[str] = None,
                 test_processed_dir: Optional[str] = None,
                 transform: Optional[str] = None,
                 dataset: Optional[str] = None,
                 num_historical_steps: int = 50,
                 num_future_steps: int = 60,
                 processor='ntp',
                 use_intention=False,
                 token_size=512,
                 **kwargs) -> None:
        super(MultiDataModule, self).__init__()
        self.root = root
        self.dataset_class = dataset
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers and num_workers > 0
        self.train_raw_dir = train_raw_dir
        self.val_raw_dir = val_raw_dir
        self.test_raw_dir = test_raw_dir
        self.train_processed_dir = train_processed_dir
        self.val_processed_dir = val_processed_dir
        self.test_processed_dir = test_processed_dir
        self.processor = processor
        self.use_intention = use_intention
        self.token_size = token_size

        train_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "train")
        val_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "val")
        test_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps)

        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: Optional[str] = None) -> None:
        self.train_dataset = MultiDataModule.dataset[self.dataset_class](self.root, 'train', processed_dir=self.train_processed_dir,
                                                                         raw_dir=self.train_raw_dir, processor=self.processor, transform=self.train_transform, token_size=self.token_size)
        self.val_dataset = MultiDataModule.dataset[self.dataset_class](None, 'val', processed_dir=self.val_processed_dir,
                                                                       raw_dir=self.val_raw_dir, processor=self.processor, transform=self.val_transform, token_size=self.token_size)
        self.test_dataset = MultiDataModule.dataset[self.dataset_class](None, 'test', processed_dir=self.test_processed_dir,
                                                                        raw_dir=self.test_raw_dir, processor=self.processor, transform=self.test_transform, token_size=self.token_size)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)


================================================
FILE: smart/datasets/__init__.py
================================================
from smart.datasets.scalable_dataset import MultiDataset


================================================
FILE: smart/datasets/preprocess.py
================================================
import torch
import numpy as np
from scipy.interpolate import interp1d
from scipy.spatial.distance import euclidean
import math
import pickle
from smart.utils import wrap_angle
import os

def cal_polygon_contour(x, y, theta, width, length):
    left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_front = np.column_stack((left_front_x, left_front_y))

    right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_front = np.column_stack((right_front_x, right_front_y))

    right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
    right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
    right_back = np.column_stack((right_back_x, right_back_y))

    left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
    left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
    left_back = np.column_stack((left_back_x, left_back_y))

    polygon_contour = np.concatenate(
        (left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)

    return polygon_contour


def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
    # Calculate the cumulative distance along the path, up-sample the polyline to 0.5 meter
    dist_along_path_list = [[0]]
    polylines_list = [[polylines[0]]]
    for i in range(1, polylines.shape[0]):
        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
                           abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))
        # start a new segment at sharp turns or large gaps
        if (heading_diff > 0.1 and euclidean_dist > 3) or euclidean_dist > 10:
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        else:
            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
            polylines_list[-1].append(polylines[i])
    # plt.plot(polylines[:, 0], polylines[:, 1])
    # plt.savefig('tmp.jpg')
    new_x_list = []
    new_y_list = []
    multi_polylines_list = []
    for idx in range(len(dist_along_path_list)):
        if len(dist_along_path_list[idx]) < 2:
            continue
        dist_along_path = np.array(dist_along_path_list[idx])
        polylines_cur = np.array(polylines_list[idx])
        # Create interpolation functions for x and y coordinates
        fx = interp1d(dist_along_path, polylines_cur[:, 0])
        fy = interp1d(dist_along_path, polylines_cur[:, 1])
        # fyaw = interp1d(dist_along_path, heading)

        # Create an array of distances at which to interpolate
        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
        # Use the interpolation functions to generate new x and y coordinates
        new_x = fx(new_dist_along_path)
        new_y = fy(new_dist_along_path)
        # new_yaw = fyaw(new_dist_along_path)
        new_x_list.append(new_x)
        new_y_list.append(new_y)

        # Combine the new x and y coordinates into a single array
        new_polylines = np.vstack((new_x, new_y)).T
        polyline_size = int(split_distace / distance)
        if new_polylines.shape[0] >= (polyline_size + 1):
            padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
            final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
        else:
            padding_size = new_polylines.shape[0]
            final_index = 0
        multi_polylines = None
        new_polylines = torch.from_numpy(new_polylines)
        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
                                  new_polylines[1:, 0] - new_polylines[:-1, 0])
        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
        new_polylines = torch.cat([new_polylines, new_heading], -1)
        if new_polylines.shape[0] >= (polyline_size + 1):
            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
            multi_polylines = multi_polylines.transpose(1, 2)
            multi_polylines = multi_polylines[:, ::5, :]
        if padding_size >= 3:
            last_polyline = new_polylines[final_index * polyline_size:]
            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
            if multi_polylines is not None:
                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
            else:
                multi_polylines = last_polyline.unsqueeze(0)
        if multi_polylines is None:
            continue
        multi_polylines_list.append(multi_polylines)
    if len(multi_polylines_list) > 0:
        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
    else:
        multi_polylines_list = None
    return multi_polylines_list
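The resampling step above reduces to: accumulate arc length along the polyline, then interpolate x and y at a fixed 0.5 m spacing. A minimal sketch using `np.interp` in place of scipy's `interp1d` (equivalent for 1-D linear interpolation); the heading-based splitting and endpoint handling of the original are omitted:

```python
import numpy as np

polyline = np.array([[0.0, 0.0], [3.0, 0.0], [3.0, 4.0]])
seg = np.linalg.norm(np.diff(polyline, axis=0), axis=1)
dist = np.concatenate([[0.0], np.cumsum(seg)])          # cumulative arc length
new_dist = np.arange(0.0, dist[-1], 0.5)                # 0.5 m spacing
resampled = np.stack([np.interp(new_dist, dist, polyline[:, 0]),
                      np.interp(new_dist, dist, polyline[:, 1])], axis=1)
```

Here the total arc length is 7 m, so the polyline is up-sampled to 14 evenly spaced points; the point at arc length 3.0 lands exactly on the corner (3, 0).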


def average_distance_vectorized(point_set1, centroids):
    dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1))
    return np.mean(dists, axis=2)


def assign_clusters(sub_X, centroids):
    distances = average_distance_vectorized(sub_X, centroids)
    return np.argmin(distances, axis=1)
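`average_distance_vectorized` broadcasts an (N, P, 2) point set against (K, P, 2) centroids to get an (N, K) matrix of mean corner-to-corner distances, and `assign_clusters` takes the row-wise argmin. A small worked instance with the same arithmetic inlined:

```python
import numpy as np

point_set = np.zeros((2, 4, 2))     # two contours with 4 corners each
point_set[1] += 1.0                 # second contour shifted by (1, 1)
centroids = np.zeros((1, 4, 2))     # one centroid contour at the origin

# (N, 1, P, 2) - (1, K, P, 2) -> (N, K, P); then average over the P corners
dists = np.sqrt(np.sum((point_set[:, None] - centroids[None]) ** 2, axis=-1))
avg = dists.mean(axis=2)            # (N, K) = (2, 1)
labels = np.argmin(avg, axis=1)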
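`average_distance_vectorized` broadcasts an (N, P, 2) point set against (K, P, 2) centroids to get an (N, K) matrix of mean corner-to-corner distances, and `assign_clusters` takes the row-wise argmin. A small worked instance with the same arithmetic inlined:

```python
import numpy as np

point_set = np.zeros((2, 4, 2))     # two contours with 4 corners each
point_set[1] += 1.0                 # second contour shifted by (1, 1)
centroids = np.zeros((1, 4, 2))     # one centroid contour at the origin

# (N, 1, P, 2) - (1, K, P, 2) -> (N, K, P); then average over the P corners
dists = np.sqrt(np.sum((point_set[:, None] - centroids[None]) ** 2, axis=-1))
avg = dists.mean(axis=2)            # (N, K) = (2, 1)
labels = np.argmin(avg, axis=1)
```

The shifted contour sits sqrt(2) away from the centroid at every corner, so its averaged distance is sqrt(2); with a single centroid both rows map to cluster 0.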


class TokenProcessor:

    def __init__(self, token_size):
        module_dir = os.path.dirname(os.path.dirname(__file__))
        self.agent_token_path = os.path.join(module_dir, f'tokens/cluster_frame_5_{token_size}.pkl')
        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        self.noise = False
        self.disturb = False
        self.shift = 5
        self.get_trajectory_token()
        self.training = False
        self.current_step = 10

    def preprocess(self, data):
        data = self.tokenize_agent(data)
        data = self.tokenize_map(data)
        del data['city']
        if 'polygon_is_intersection' in data['map_polygon']:
            del data['map_polygon']['polygon_is_intersection']
        if 'route_type' in data['map_polygon']:
            del data['map_polygon']['route_type']
        return data

    def get_trajectory_token(self):
        agent_token_data = pickle.load(open(self.agent_token_path, 'rb'))
        map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
        self.trajectory_token = agent_token_data['token']
        self.trajectory_token_all = agent_token_data['token_all']
        self.map_token = {'traj_src': map_token_traj['traj_src'], }
        self.token_last = {}
        for k, v in self.trajectory_token_all.items():
            token_last = torch.from_numpy(v[:, -2:]).to(torch.float)
            diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3]
            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(token_last.shape[0], 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = -sin
            rot_mat[:, 1, 0] = sin
            rot_mat[:, 1, 1] = cos
            agent_token = torch.bmm(token_last[:, 1], rot_mat)
            agent_token -= token_last[:, 0].mean(1)[:, None, :]
            self.token_last[k] = agent_token.numpy()

    def clean_heading(self, data):
        heading = data['agent']['heading']
        valid = data['agent']['valid_mask']
        pi = torch.tensor(torch.pi)
        n_vehicles, n_frames = heading.shape

        heading_diff_raw = heading[:, :-1] - heading[:, 1:]
        heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
        heading_diff[heading_diff > pi] -= 2 * pi
        heading_diff[heading_diff < -pi] += 2 * pi

        valid_pairs = valid[:, :-1] & valid[:, 1:]

        for i in range(n_frames - 1):
            change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1]

            heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()]

            if i < n_frames - 2:
                heading_diff_raw = heading[:, i + 1] - heading[:, i + 2]
                heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
                heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi
                heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi
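`clean_heading` wraps frame-to-frame heading differences into [-pi, pi) via `remainder(diff + pi, 2*pi) - pi` before thresholding jumps. The same wrap in plain Python (illustrative helper name):

```python
import math

def wrap_to_pi(a):
    # map any angle to the interval [-pi, pi), as clean_heading does
    return (a + math.pi) % (2 * math.pi) - math.pi
```

This makes the subsequent `abs(heading_diff) > 1.0` check compare true angular jumps rather than raw differences that may straddle the ±pi boundary.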

    def tokenize_agent(self, data):
        if data['agent']["velocity"].shape[1] == 90:
            print(data['scenario_id'], data['agent']["velocity"].shape)
        interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * (
                data['agent']['position'][:, self.current_step, 0] != 0)
        if data['agent']["velocity"].shape[-1] == 2:
            data['agent']["velocity"] = torch.cat([data['agent']["velocity"],
                                                   torch.zeros(data['agent']["velocity"].shape[0],
                                                               data['agent']["velocity"].shape[1], 1)], dim=-1)
        vel = data['agent']["velocity"][interplote_mask, self.current_step]
        data['agent']['position'][interplote_mask, self.current_step - 1, :3] = data['agent']['position'][
                                                                                interplote_mask, self.current_step,
                                                                                :3] - vel * 0.1
        data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True
        data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][
            interplote_mask, self.current_step]
        data['agent']["velocity"][interplote_mask, self.current_step - 1] = data['agent']["velocity"][
            interplote_mask, self.current_step]

        data['agent']['type'] = data['agent']['type'].to(torch.uint8)

        self.clean_heading(data)
        matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * (
                data['agent']['valid_mask'][:, self.current_step - 5] == False)

        interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0)
        data['agent']['valid_mask'][interplote_mask_first, 0] = True

        agent_pos = data['agent']['position'][:, :, :2]
        valid_mask = data['agent']['valid_mask']

        valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)
        token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]
        agent_type = data['agent']['type']
        agent_category = data['agent']['category']
        agent_heading = data['agent']['heading']
        vehicle_mask = agent_type == 0
        cyclist_mask = agent_type == 2
        ped_mask = agent_type == 1

        veh_pos = agent_pos[vehicle_mask, :, :]
        veh_valid_mask = valid_mask[vehicle_mask, :]
        cyc_pos = agent_pos[cyclist_mask, :, :]
        cyc_valid_mask = valid_mask[cyclist_mask, :]
        ped_pos = agent_pos[ped_mask, :, :]
        ped_valid_mask = valid_mask[ped_mask, :]

        veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask],
                                                              'veh', agent_category[vehicle_mask],
                                                              matching_extra_mask[vehicle_mask])
        ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped',
                                                              agent_category[ped_mask], matching_extra_mask[ped_mask])
        cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask],
                                                              'cyc', agent_category[cyclist_mask],
                                                              matching_extra_mask[cyclist_mask])

        token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)
        token_index[vehicle_mask] = veh_token_index
        token_index[ped_mask] = ped_token_index
        token_index[cyclist_mask] = cyc_token_index

        token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],
                                     veh_token_contour.shape[2], veh_token_contour.shape[3]))
        token_contour[vehicle_mask] = veh_token_contour
        token_contour[ped_mask] = ped_token_contour
        token_contour[cyclist_mask] = cyc_token_contour

        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float)
        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float)
        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float)

        agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2))
        agent_token_traj[vehicle_mask] = trajectory_token_veh
        agent_token_traj[ped_mask] = trajectory_token_ped
        agent_token_traj[cyclist_mask] = trajectory_token_cyc

        if not self.training:
            token_valid_mask[matching_extra_mask, 1] = True

        data['agent']['token_idx'] = token_index
        data['agent']['token_contour'] = token_contour
        token_pos = token_contour.mean(dim=2)
        data['agent']['token_pos'] = token_pos
        diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]
        data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
        data['agent']['agent_valid_mask'] = token_valid_mask

        vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2),
                         ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1)
        vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool),
                                    (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1)
        vel[~vel_valid_mask] = 0
        vel[data['agent']['valid_mask'][:, self.current_step], 1] = data['agent']['velocity'][
                                                                    data['agent']['valid_mask'][:, self.current_step],
                                                                    self.current_step, :2]

        data['agent']['token_velocity'] = vel

        return data

    def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask):
        agent_token_src = self.trajectory_token[category]
        token_last = self.token_last[category]
        if self.shift <= 2:
            if category == 'veh':
                width = 1.0
                length = 2.4
            elif category == 'cyc':
                width = 0.5
                length = 1.5
            else:
                width = 0.5
                length = 0.5
        else:
            if category == 'veh':
                width = 2.0
                length = 4.8
            elif category == 'cyc':
                width = 1.0
                length = 2.0
            else:
                width = 1.0
                length = 1.0

        prev_heading = heading[:, 0]
        prev_pos = pos[:, 0]
        agent_num, num_step, feat_dim = pos.shape
        token_num, token_contour_dim, feat_dim = agent_token_src.shape
        agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)
        token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0)
        token_index_list = []
        token_contour_list = []
        prev_token_idx = None

        for i in range(self.shift, pos.shape[1], self.shift):
            theta = prev_heading
            cur_heading = heading[:, i]
            cur_pos = pos[:, i]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(agent_num, 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num,
                                                                                                              token_num,
                                                                                                              token_contour_dim,
                                                                                                              feat_dim)
            agent_token_world += prev_pos[:, None, None, :]

            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
            agent_token_index = torch.from_numpy(np.argmin(
                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
                axis=-1))
            if prev_token_idx is not None and self.noise:
                same_idx = prev_token_idx == agent_token_index
                same_idx[:] = True
                topk_indices = np.argsort(
                    np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
                            axis=2), axis=-1)[:, :5]
                sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
                agent_token_index[same_idx] = \
                    torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]

            token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index]

            diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :]

            prev_heading = heading[:, i].clone()
            prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
                valid_mask[:, i - self.shift]]

            prev_pos = pos[:, i].clone()
            prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]]
            prev_token_idx = agent_token_index
            token_index_list.append(agent_token_index[:, None])
            token_contour_list.append(token_contour_select[:, None, ...])

        token_index = torch.cat(token_index_list, dim=1)
        token_contour = torch.cat(token_contour_list, dim=1)

        # extra matching
        if not self.training:
            theta = heading[extra_mask, self.current_step - 1]
            prev_pos = pos[extra_mask, self.current_step - 1]
            cur_pos = pos[extra_mask, self.current_step]
            cur_heading = heading[extra_mask, self.current_step]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(
                extra_mask.sum(), token_num, token_contour_dim, feat_dim)
            agent_token_world += prev_pos[:, None, None, :]

            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
            agent_token_index = torch.from_numpy(np.argmin(
                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
                axis=-1))
            token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]

            token_index[extra_mask, 1] = agent_token_index
            token_contour[extra_mask, 1] = token_contour_select

        return token_index, token_contour
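`match_token` moves the token template contours into the world frame by right-multiplying row-vector points with the per-agent matrix [[cos, sin], [-sin, cos]] and then translating by the previous position. A numpy sketch of that convention (row vectors times this matrix rotate by +theta):

```python
import numpy as np

theta = np.pi / 2                              # rotate by 90 degrees
rot = np.array([[np.cos(theta), np.sin(theta)],
                [-np.sin(theta), np.cos(theta)]])
local = np.array([[1.0, 0.0]])                 # token point on the local x-axis
world = local @ rot                            # lands on the world y-axis
```

Note that `get_trajectory_token` builds the transposed matrix (with `-sin` in the upper-right entry), which rotates row vectors by -theta instead.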

    def tokenize_map(self, data):
        data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
        data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
        pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
        pt_type = data['map_point']['type'].to(torch.uint8)
        pt_side = torch.zeros_like(pt_type)
        pt_pos = data['map_point']['position'][:, :2]
        data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
        pt_heading = data['map_point']['orientation']
        split_polyline_type = []
        split_polyline_pos = []
        split_polyline_theta = []
        split_polyline_side = []
        pl_idx_list = []
        split_polygon_type = []

        for i in sorted(np.unique(pt2pl[1])):
            index = pt2pl[0, pt2pl[1] == i]
            polygon_type = data['map_polygon']["type"][i]
            cur_side = pt_side[index]
            cur_type = pt_type[index]
            cur_pos = pt_pos[index]
            cur_heading = pt_heading[index]

            for side_val in np.unique(cur_side):
                for type_val in np.unique(cur_type):
                    if type_val == 13:
                        continue
                    indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
                    if len(indices) <= 2:
                        continue
                    split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
                    if split_polyline is None:
                        continue
                    new_cur_type = cur_type[indices][0]
                    new_cur_side = cur_side[indices][0]
                    map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
                    new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
                    new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
                    cur_pl_idx = torch.Tensor([i])
                    new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
                    split_polyline_pos.append(split_polyline[..., :2])
                    split_polyline_theta.append(split_polyline[..., 2])
                    split_polyline_type.append(new_cur_type)
                    split_polyline_side.append(new_cur_side)
                    pl_idx_list.append(new_cur_pl_idx)
                    split_polygon_type.append(map_polygon_type)

        split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
        split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
        split_polyline_type = torch.cat(split_polyline_type, dim=0)
        split_polyline_side = torch.cat(split_polyline_side, dim=0)
        split_polygon_type = torch.cat(split_polygon_type, dim=0)
        pl_idx_list = torch.cat(pl_idx_list, dim=0)
        data['map_save'] = {}
        data['pt_token'] = {}
        data['map_save']['traj_pos'] = split_polyline_pos
        data['map_save']['traj_theta'] = split_polyline_theta[:, 0]
        data['map_save']['pl_idx_list'] = pl_idx_list
        data['pt_token']['type'] = split_polyline_type
        data['pt_token']['side'] = split_polyline_side
        data['pt_token']['pl_type'] = split_polygon_type
        data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
        return data
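The loop above splits each polygon's points into polylines grouped by (side, type), skipping type 13 and interpolating each group. A minimal numpy sketch of just the grouping step, using hypothetical toy arrays (interpolation and the minimum-length check are omitted):

```python
import numpy as np

# Hypothetical stand-ins for cur_side / cur_type within one polygon.
cur_side = np.array([0, 0, 1, 1, 1, 0])
cur_type = np.array([2, 2, 2, 5, 5, 13])

groups = {}
for side_val in np.unique(cur_side):
    for type_val in np.unique(cur_type):
        if type_val == 13:  # excluded, as in the loop above
            continue
        idx = np.where((cur_side == side_val) & (cur_type == type_val))[0]
        if len(idx) == 0:
            continue
        # Each group is one candidate polyline to interpolate and tokenize.
        groups[(int(side_val), int(type_val))] = idx
```

The real code additionally requires more than two points per group before interpolating.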

================================================
FILE: smart/datasets/scalable_dataset.py
================================================
import os
import pickle
from typing import Callable, List, Optional, Tuple, Union
import pandas as pd
from torch_geometric.data import Dataset
from smart.utils.log import Logging
import numpy as np
from .preprocess import TokenProcessor


def distance(point1, point2):
    return np.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2)


class MultiDataset(Dataset):
    def __init__(self,
                 root: str,
                 split: str,
                 raw_dir: List[str] = None,
                 processed_dir: List[str] = None,
                 transform: Optional[Callable] = None,
                 dim: int = 3,
                 num_historical_steps: int = 50,
                 num_future_steps: int = 60,
                 predict_unseen_agents: bool = False,
                 vector_repr: bool = True,
                 cluster: bool = False,
                 processor=None,
                 use_intention=False,
                 token_size=512) -> None:
        self.logger = Logging().log(level='DEBUG')
        self.root = root
        self.well_done = [0]
        if split not in ('train', 'val', 'test'):
            raise ValueError(f'{split} is not a valid split')
        self.split = split
        self.training = split == 'train'
        self.logger.debug("Starting loading dataset")
        self._raw_file_names = []
        self._raw_paths = []
        self._raw_file_dataset = []
        if raw_dir is not None:
            self._raw_dir = raw_dir
            for cur_raw_dir in self._raw_dir:
                cur_raw_dir = os.path.expanduser(os.path.normpath(cur_raw_dir))
                dataset = "waymo"
                file_list = os.listdir(cur_raw_dir)
                self._raw_file_names.extend(file_list)
                self._raw_paths.extend([os.path.join(cur_raw_dir, f) for f in file_list])
                self._raw_file_dataset.extend([dataset for _ in range(len(file_list))])
        if self.root is not None:
            split_datainfo = os.path.join(root, "split_datainfo.pkl")
            with open(split_datainfo, 'rb') as f:
                split_datainfo = pickle.load(f)
            if split == "test":
                split = "val"
            self._processed_file_names = split_datainfo[split]
        self.dim = dim
        self.num_historical_steps = num_historical_steps
        self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names)
        self.logger.debug(f"Size of the {split} dataset: {self._num_samples}")
        # NOTE: the token vocabulary size is hard-coded to 2048 here (matching
        # smart/tokens/cluster_frame_5_2048.pkl); the token_size argument is unused.
        self.token_processor = TokenProcessor(2048)
        super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)

    @property
    def raw_dir(self) -> List[str]:
        return self._raw_dir

    @property
    def raw_paths(self) -> List[str]:
        return self._raw_paths

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        return self._raw_file_names

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        return self._processed_file_names

    def len(self) -> int:
        return self._num_samples

    def generate_ref_token(self):
        pass

    def get(self, idx: int):
        with open(self.raw_paths[idx], 'rb') as handle:
            data = pickle.load(handle)
        data = self.token_processor.preprocess(data)
        return data


================================================
FILE: smart/layers/__init__.py
================================================

from smart.layers.attention_layer import AttentionLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from smart.layers.mlp_layer import MLPLayer


================================================
FILE: smart/layers/attention_layer.py
================================================

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax

from smart.utils import weight_init


class AttentionLayer(MessagePassing):

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 bipartite: bool,
                 has_pos_emb: bool,
                 **kwargs) -> None:
        super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_pos_emb = has_pos_emb
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
        self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
        if has_pos_emb:
            self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
            self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
        self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.ff_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        if bipartite:
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
        else:
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = self.attn_prenorm_x_src
        if has_pos_emb:
            self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
        self.attn_postnorm = nn.LayerNorm(hidden_dim)
        self.ff_prenorm = nn.LayerNorm(hidden_dim)
        self.ff_postnorm = nn.LayerNorm(hidden_dim)
        self.apply(weight_init)

    def forward(self,
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                r: Optional[torch.Tensor],
                edge_index: torch.Tensor) -> torch.Tensor:
        if isinstance(x, torch.Tensor):
            x_src = x_dst = self.attn_prenorm_x_src(x)
        else:
            x_src, x_dst = x
            x_src = self.attn_prenorm_x_src(x_src)
            x_dst = self.attn_prenorm_x_dst(x_dst)
            x = x[1]
        if self.has_pos_emb and r is not None:
            r = self.attn_prenorm_r(r)
        x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
        x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
        return x

    def message(self,
                q_i: torch.Tensor,
                k_j: torch.Tensor,
                v_j: torch.Tensor,
                r: Optional[torch.Tensor],
                index: torch.Tensor,
                ptr: Optional[torch.Tensor]) -> torch.Tensor:
        if self.has_pos_emb and r is not None:
            k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
            v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
        sim = (q_i * k_j).sum(dim=-1) * self.scale
        attn = softmax(sim, index, ptr)
        self.attention_weight = attn.sum(-1).detach()
        attn = self.attn_drop(attn)
        return v_j * attn.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x_dst: torch.Tensor) -> torch.Tensor:
        inputs = inputs.view(-1, self.num_heads * self.head_dim)
        g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
        return inputs + g * (self.to_s(x_dst) - inputs)

    def _attn_block(self,
                    x_src: torch.Tensor,
                    x_dst: torch.Tensor,
                    r: Optional[torch.Tensor],
                    edge_index: torch.Tensor) -> torch.Tensor:
        q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
        k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
        v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
        return self.to_out(agg)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff_mlp(x)
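The `update` step above fuses the aggregated messages with the destination node through a learned gate: `out = agg + g * (s(x_dst) - agg)`, i.e. an element-wise convex combination of the message aggregate and a skip path. A framework-free numpy sketch with hypothetical toy weights standing in for `to_g` / `to_s`:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
hidden = 4
agg = rng.normal(size=(3, hidden))    # aggregated messages per node
x_dst = rng.normal(size=(3, hidden))  # destination node features

W_g = rng.normal(size=(2 * hidden, hidden)) * 0.1  # toy stand-in for to_g
W_s = np.eye(hidden)                                # toy stand-in for to_s
skip = x_dst @ W_s

# g in (0, 1): g -> 0 keeps the aggregated messages, g -> 1 follows the skip path.
g = sigmoid(np.concatenate([agg, x_dst], axis=-1) @ W_g)
out = agg + g * (skip - agg)
```

Because `g` is strictly between 0 and 1, each output element lies between the corresponding elements of `agg` and `skip`.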


================================================
FILE: smart/layers/fourier_embedding.py
================================================
import math
from typing import List, Optional
import torch
import torch.nn as nn

from smart.utils import weight_init


class FourierEmbedding(nn.Module):

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_freq_bands: int) -> None:
        super(FourierEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
        self.mlps = nn.ModuleList(
            [nn.Sequential(
                nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
            )
                for _ in range(input_dim)])
        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
            # Warning: if your data are noisy, don't use learnable sinusoidal embedding
            x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
            continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
            for i in range(self.input_dim):
                continuous_embs[i] = self.mlps[i](x[:, i])
            x = torch.stack(continuous_embs).sum(dim=0)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return self.to_out(x)


class MLPEmbedding(nn.Module):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int) -> None:
        super(MLPEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim))
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if continuous_inputs is None:
            if categorical_embs is not None:
                x = torch.stack(categorical_embs).sum(dim=0)
            else:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
        else:
            x = self.mlp(continuous_inputs)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return x
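`FourierEmbedding.forward` projects each continuous input dimension onto a bank of learned frequencies and concatenates `[cos, sin, raw]` before the per-dimension MLPs. The feature construction alone can be sketched in numpy (fixed frequencies here, whereas the module learns them):

```python
import numpy as np

def fourier_features(x, freqs):
    """x: (N, D) continuous inputs; freqs: (D, F) per-dimension frequencies.
    Returns (N, D, 2*F + 1): [cos, sin, raw] per input dimension, mirroring
    the concatenation in FourierEmbedding.forward."""
    ang = x[..., None] * freqs * 2.0 * np.pi          # (N, D, F)
    return np.concatenate([np.cos(ang), np.sin(ang), x[..., None]], axis=-1)

x = np.array([[0.0, 0.25]])
freqs = np.ones((2, 3))
feats = fourier_features(x, freqs)
```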


================================================
FILE: smart/layers/mlp_layer.py
================================================

import torch
import torch.nn as nn

from smart.utils import weight_init


class MLPLayer(nn.Module):

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 output_dim: int) -> None:
        super(MLPLayer, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )
        self.apply(weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


================================================
FILE: smart/metrics/__init__.py
================================================

from smart.metrics.average_meter import AverageMeter
from smart.metrics.min_ade import minADE
from smart.metrics.min_fde import minFDE
from smart.metrics.next_token_cls import TokenCls


================================================
FILE: smart/metrics/average_meter.py
================================================

import torch
from torchmetrics import Metric


class AverageMeter(Metric):

    def __init__(self, **kwargs) -> None:
        super(AverageMeter, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, val: torch.Tensor) -> None:
        self.sum += val.sum()
        self.count += val.numel()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
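`AverageMeter` keeps a running sum and element count (each reduced by summation across processes) so `compute()` yields the global mean even when updates arrive in uneven batches. The same bookkeeping in plain Python:

```python
class PlainAverageMeter:
    """Minimal, framework-free analogue of the AverageMeter above."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, values):
        # Accumulate both sum and count so the mean weights every element equally.
        self.sum += sum(values)
        self.count += len(values)

    def compute(self):
        return self.sum / self.count

m = PlainAverageMeter()
m.update([1.0, 2.0, 3.0])
m.update([5.0])
```

Averaging per-batch means instead would weight the single-element batch as heavily as the three-element one; sum/count bookkeeping avoids that bias.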


================================================
FILE: smart/metrics/min_ade.py
================================================

from typing import Optional

import torch
from torchmetrics import Metric

from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter


class minMultiADE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiADE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'FDE') -> None:
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        if min_criterion == 'FDE':
            inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
            inds_best = torch.norm(
                pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
            self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
                          valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
        elif min_criterion == 'ADE':
            self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
                          valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
        else:
            raise ValueError('{} is not a valid criterion'.format(min_criterion))
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class minADE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minADE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'ADE') -> None:
        eval_timestep = min(self.eval_timestep, pred.shape[1])
        self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
        self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
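For the multi-modal variant above, minADE averages the displacement error over valid timesteps and takes the minimum over the K predicted modes. A single-agent numpy sketch on toy trajectories:

```python
import numpy as np

def min_ade(pred, target, valid_mask):
    """pred: (K, T, 2) predicted modes, target: (T, 2), valid_mask: (T,) bool."""
    err = np.linalg.norm(pred - target[None], axis=-1)      # (K, T) per-step error
    ade = (err * valid_mask).sum(-1) / valid_mask.sum()     # (K,) mean over valid steps
    return ade.min()                                        # best mode

target = np.zeros((4, 2))
pred = np.stack([np.ones((4, 2)), np.zeros((4, 2))])        # one offset mode, one perfect mode
mask = np.ones(4, dtype=bool)
```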


================================================
FILE: smart/metrics/min_fde.py
================================================
from typing import Optional

import torch
from torchmetrics import Metric

from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter


class minMultiFDE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
        self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                               target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
                               p=2, dim=-1).min(dim=-1)[0].sum()
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class minFDE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
        self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
                      valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
        self.count += valid_mask[:, eval_timestep-1].sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


================================================
FILE: smart/metrics/next_token_cls.py
================================================
from typing import Optional

import torch
from torchmetrics import Metric

from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter


class TokenCls(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(TokenCls, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:
        target = target[..., None]
        acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask
        self.sum += acc.sum()
        self.count += valid_mask.sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
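`TokenCls` above measures top-k classification accuracy over predicted motion tokens: a step counts as correct if the ground-truth token id appears among the first `max_guesses` candidates. The same computation in numpy:

```python
import numpy as np

def topk_token_acc(pred_tokens, target, valid_mask, k=6):
    """pred_tokens: (N, M) ranked candidate token ids, target: (N,) ground-truth
    ids, valid_mask: (N,) bool. Mirrors the TokenCls.update logic above."""
    hit = (pred_tokens[:, :k] == target[:, None]).any(axis=1) & valid_mask
    return hit.sum() / valid_mask.sum()

pred = np.array([[3, 1, 2], [7, 8, 9]])
target = np.array([2, 5])
mask = np.array([True, True])
```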


================================================
FILE: smart/metrics/utils.py
================================================
from typing import Optional, Tuple

import torch
from torch_scatter import gather_csr
from torch_scatter import segment_csr


def topk(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred, prob
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
        else:
            pred_topk = pred[:, :max_guesses]
            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred_topk, prob_topk


def topkind(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred, prob, None
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
        else:
            pred_topk = pred[:, :max_guesses]
            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
            # Without probabilities, the top-k picks are simply the first k modes.
            inds_topk = torch.arange(max_guesses, device=pred.device).unsqueeze(0).expand(pred.size(0), -1)
        return pred_topk, prob_topk, inds_topk


def valid_filter(
        pred: torch.Tensor,
        target: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        valid_mask: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                                                       torch.Tensor, torch.Tensor]:
    if valid_mask is None:
        valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
    if keep_invalid_final_step:
        filter_mask = valid_mask.any(dim=-1)
    else:
        filter_mask = valid_mask[:, -1]
    pred = pred[filter_mask]
    target = target[filter_mask]
    if prob is not None:
        prob = prob[filter_mask]
    valid_mask = valid_mask[filter_mask]
    if ptr is not None:
        num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
        ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
        torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
    else:
        ptr = target.new_tensor([0, target.size(0)])
    return pred, target, prob, valid_mask, ptr
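`valid_filter` above drops agents with no usable ground truth: with `keep_invalid_final_step=True` an agent is kept if *any* timestep is valid, otherwise only if its *final* timestep is valid. The two filter masks in numpy terms:

```python
import numpy as np

# Per-agent validity over 3 timesteps (toy example).
valid_mask = np.array([[True, True, False],
                       [False, False, False],
                       [True, False, False]])

keep_any = valid_mask.any(axis=-1)   # keep_invalid_final_step=True
keep_last = valid_mask[:, -1]        # keep_invalid_final_step=False
```

The stricter final-step criterion is useful when the metric (e.g. FDE) needs a valid endpoint.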


def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
    """Greedy NMS over trajectory endpoints, with scores derived from goal density.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, num_feat_dim)
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, num_feat_dim)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
    pred_goals = pred_trajs[:, :, -1, :]
    dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
    nearby_neighbor = dist < dist_thresh
    pred_scores = nearby_neighbor.sum(dim=-1) / num_modes

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)

    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
    point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
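The selection loop above is a greedy NMS over trajectory endpoints: repeatedly take the highest-scoring remaining mode, then suppress the scores of every mode whose goal lies within `dist_thresh` of it. A simplified single-scene numpy sketch (the batched version also re-sorts and maps indices back through `sorted_idxs`):

```python
import numpy as np

def greedy_goal_nms(goals, scores, dist_thresh, num_ret):
    """goals: (M, 2) trajectory endpoints, scores: (M,). Returns picked indices."""
    dist = np.linalg.norm(goals[:, None] - goals[None], axis=-1)
    cover = dist < dist_thresh          # which goals each pick suppresses
    val = scores.astype(float).copy()
    picked = []
    for _ in range(num_ret):
        cur = int(val.argmax())
        picked.append(cur)
        val[cover[cur]] = 0.0           # suppress nearby goals
        val[cur] = -1.0                 # never pick the same mode twice
    return picked

goals = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 0.0]])
scores = np.array([0.9, 0.8, 0.5])
```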


def batch_nms(pred_trajs, pred_scores,
              dist_thresh, num_ret_modes=6,
              mode='static', speed=None):
    """

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, num_feat_dim)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)

    if mode == "speed":
        scale = torch.ones(batch_size).to(sorted_pred_goals.device)
        lon_dist_thresh = 4 * scale
        lat_dist_thresh = 0.5 * scale
        lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
        lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
        point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
    else:
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
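The selection loop above can be illustrated for a single batch element in plain Python. This is a hedged sketch of the greedy suppress-and-pick logic (a hypothetical `greedy_nms_1d` helper, not part of the repo): pick the highest-scoring mode, zero out every mode whose endpoint falls within `dist_thresh` of it, and mark the picked mode so it is never chosen again.

```python
import math


def greedy_nms_1d(goals, scores, dist_thresh, num_ret_modes):
    """Greedy endpoint NMS for one batch element (pure-Python sketch).

    Mirrors the loop in batch_nms: repeatedly pick the highest-scoring
    remaining mode, then suppress every mode whose endpoint lies within
    dist_thresh of the picked one.
    """
    n = len(goals)
    point_val = list(scores)
    dist = [[math.dist(goals[i], goals[j]) for j in range(n)] for i in range(n)]

    ret_idxs = []
    for _ in range(num_ret_modes):
        cur = max(range(n), key=lambda i: point_val[i])
        ret_idxs.append(cur)
        for j in range(n):
            if dist[cur][j] < dist_thresh:
                point_val[j] = 0.0  # suppress modes covered by the pick
        point_val[cur] = -1.0       # never re-pick the same mode
    return ret_idxs


# Modes 0 and 1 share an endpoint, so picking 0 suppresses 1.
picked = greedy_nms_1d([(0.0, 0.0), (0.1, 0.0), (5.0, 0.0), (10.0, 0.0)],
                       [0.9, 0.8, 0.7, 0.6], dist_thresh=1.0, num_ret_modes=3)
```

The batched version in `batch_nms` vectorizes exactly this over the batch dimension via the precomputed `point_cover_mask` and the `point_val_selected` sentinel.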


def batch_nms_token(pred_trajs, pred_scores,
                    dist_thresh, num_ret_modes=6,
                    mode='static', speed=None):
    """
    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_feat_dim)

    if mode == "nearby":
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        values, _ = torch.topk(dist, 5, dim=-1, largest=False)
        threshold = values[..., -1]
        point_cover_mask = dist < threshold[..., None]
    else:
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
    ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)  # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_goals, ret_scores, ret_idxs


================================================
FILE: smart/model/__init__.py
================================================
from smart.model.smart import SMART


================================================
FILE: smart/model/smart.py
================================================
import contextlib
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from smart.metrics import minADE
from smart.metrics import minFDE
from smart.metrics import TokenCls
from smart.modules import SMARTDecoder
from torch.optim.lr_scheduler import LambdaLR
import math
import numpy as np
import pickle
from collections import defaultdict
import os
from waymo_open_dataset.protos import sim_agents_submission_pb2


def cal_polygon_contour(x, y, theta, width, length):
    left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
    left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
    left_front = (left_front_x, left_front_y)

    right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
    right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
    right_front = (right_front_x, right_front_y)

    right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
    right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
    right_back = (right_back_x, right_back_y)

    left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
    left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
    left_back = (left_back_x, left_back_y)
    polygon_contour = [left_front, right_front, right_back, left_back]

    return polygon_contour
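The corner arithmetic above is easy to sanity-check by hand. The sketch below restates the function in condensed form (so it runs standalone, outside the repo) and evaluates an axis-aligned box, where the corners must land at `(x ± length/2, y ± width/2)`:

```python
import math


def cal_polygon_contour(x, y, theta, width, length):
    """Corners of an oriented (length x width) box centered at (x, y) with heading theta.

    Condensed from the function above so this check is self-contained.
    """
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    half_l, half_w = 0.5 * length, 0.5 * width
    left_front = (x + half_l * cos_t - half_w * sin_t, y + half_l * sin_t + half_w * cos_t)
    right_front = (x + half_l * cos_t + half_w * sin_t, y + half_l * sin_t - half_w * cos_t)
    right_back = (x - half_l * cos_t + half_w * sin_t, y - half_l * sin_t - half_w * cos_t)
    left_back = (x - half_l * cos_t - half_w * sin_t, y - half_l * sin_t + half_w * cos_t)
    return [left_front, right_front, right_back, left_back]


# Axis-aligned box (theta = 0): corners at (+/- length/2, +/- width/2).
contour = cal_polygon_contour(0.0, 0.0, 0.0, width=2.0, length=4.0)
# Rotating by 90 degrees swaps the roles of length and width.
rotated = cal_polygon_contour(0.0, 0.0, math.pi / 2, width=2.0, length=4.0)
```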


def joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene:
    states = states.numpy()
    simulated_trajectories = []
    for i_object in range(len(object_ids)):
        simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory(
            center_x=states[i_object, :, 0], center_y=states[i_object, :, 1],
            center_z=states[i_object, :, 2], heading=states[i_object, :, 3],
            object_id=object_ids[i_object].item()
        ))
    return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories)


class SMART(pl.LightningModule):

    def __init__(self, model_config) -> None:
        super(SMART, self).__init__()
        self.save_hyperparameters()
        self.model_config = model_config
        self.warmup_steps = model_config.warmup_steps
        self.lr = model_config.lr
        self.total_steps = model_config.total_steps
        self.dataset = model_config.dataset
        self.input_dim = model_config.input_dim
        self.hidden_dim = model_config.hidden_dim
        self.output_dim = model_config.output_dim
        self.output_head = model_config.output_head
        self.num_historical_steps = model_config.num_historical_steps
        self.num_future_steps = model_config.decoder.num_future_steps
        self.num_freq_bands = model_config.num_freq_bands
        self.vis_map = False
        self.noise = True
        module_dir = os.path.dirname(os.path.dirname(__file__))
        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        self.init_map_token()
        self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl')
        token_data = self.get_trajectory_token()
        self.encoder = SMARTDecoder(
            dataset=model_config.dataset,
            input_dim=model_config.input_dim,
            hidden_dim=model_config.hidden_dim,
            num_historical_steps=model_config.num_historical_steps,
            num_freq_bands=model_config.num_freq_bands,
            num_heads=model_config.num_heads,
            head_dim=model_config.head_dim,
            dropout=model_config.dropout,
            num_map_layers=model_config.decoder.num_map_layers,
            num_agent_layers=model_config.decoder.num_agent_layers,
            pl2pl_radius=model_config.decoder.pl2pl_radius,
            pl2a_radius=model_config.decoder.pl2a_radius,
            a2a_radius=model_config.decoder.a2a_radius,
            time_span=model_config.decoder.time_span,
            map_token={'traj_src': self.map_token['traj_src']},
            token_data=token_data,
            token_size=model_config.decoder.token_size
        )
        self.minADE = minADE(max_guesses=1)
        self.minFDE = minFDE(max_guesses=1)
        self.TokenCls = TokenCls(max_guesses=1)

        self.test_predictions = dict()
        self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.inference_token = False
        self.rollout_num = 1

    def get_trajectory_token(self):
        token_data = pickle.load(open(self.token_path, 'rb'))
        self.trajectory_token = token_data['token']
        self.trajectory_token_traj = token_data['traj']
        self.trajectory_token_all = token_data['token_all']
        return token_data

    def init_map_token(self):
        self.argmin_sample_len = 3
        map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
        self.map_token = {'traj_src': map_token_traj['traj_src']}
        traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],
                                    self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])
        indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()
        self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)
        self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
        self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)

    def forward(self, data: HeteroData):
        res = self.encoder(data)
        return res

    def inference(self, data: HeteroData):
        res = self.encoder.inference(data)
        return res

    def maybe_autocast(self, dtype=torch.float16):
        enable_autocast = self.device != torch.device("cpu")

        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    def training_step(self,
                      data,
                      batch_idx):
        data = self.match_token_map(data)
        data = self.sample_pt_pred(data)
        if isinstance(data, Batch):
            data['agent']['av_index'] += data['agent']['ptr'][:-1]
        pred = self(data)
        next_token_prob = pred['next_token_prob']
        next_token_idx_gt = pred['next_token_idx_gt']
        next_token_eval_mask = pred['next_token_eval_mask']
        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
        loss = cls_loss
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        return loss

    def validation_step(self,
                        data,
                        batch_idx):
        data = self.match_token_map(data)
        data = self.sample_pt_pred(data)
        if isinstance(data, Batch):
            data['agent']['av_index'] += data['agent']['ptr'][:-1]
        pred = self(data)
        next_token_idx = pred['next_token_idx']
        next_token_idx_gt = pred['next_token_idx_gt']
        next_token_eval_mask = pred['next_token_eval_mask']
        next_token_prob = pred['next_token_prob']
        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
        loss = cls_loss
        self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
                             valid_mask=next_token_eval_mask[next_token_eval_mask])
        self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)

        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]  # * (data['agent']['category'] == 3)
        if self.inference_token:
            pred = self.inference(data)
            pos_a = pred['pos_a']
            gt = pred['gt']
            valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:]
            pred_traj = pred['pred_traj']
            # next_token_idx = pred['next_token_idx'][..., None]
            # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:]
            # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:]
            # next_token_eval_mask[:, 1:] = False
            # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
            #                      valid_mask=next_token_eval_mask[next_token_eval_mask])
            # self.log('val_inference_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
            eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]

            self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
            self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
            # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute())

            self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)
            self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)

    def on_validation_start(self):
        self.gt = []
        self.pred = []
        self.scenario_rollouts = []
        self.batch_metric = defaultdict(list)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)

        def lr_lambda(current_step):
            if current_step + 1 < self.warmup_steps:
                return float(current_step + 1) / float(max(1, self.warmup_steps))
            return max(
                0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))
            )

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
        return [optimizer], [lr_scheduler]
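The `lr_lambda` closure above implements linear warmup followed by cosine decay to zero. Extracted as a standalone function (with hypothetical step counts chosen for illustration), the schedule's endpoints can be checked directly:

```python
import math


def warmup_cosine(current_step, warmup_steps, total_steps):
    """Multiplier applied to the base lr: linear warmup, then cosine decay to 0.

    Same formula as the lr_lambda closure in configure_optimizers.
    """
    if current_step + 1 < warmup_steps:
        return float(current_step + 1) / float(max(1, warmup_steps))
    progress = (current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


# With 100 warmup steps out of 1000 total:
first = warmup_cosine(0, 100, 1000)     # tiny fraction of base lr
peak = warmup_cosine(100, 100, 1000)    # warmup done: full base lr
final = warmup_cosine(1000, 100, 1000)  # decayed to ~0 at the end
```

Because `LambdaLR` multiplies the optimizer's base `lr` by this factor each step, the peak learning rate equals `model_config.lr` exactly once warmup completes.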

    def load_params_from_file(self, filename, logger, to_cpu=False):
        if not os.path.isfile(filename):
            raise FileNotFoundError(filename)

        logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
        loc_type = torch.device('cpu') if to_cpu else None
        checkpoint = torch.load(filename, map_location=loc_type)
        model_state_disk = checkpoint['state_dict']

        version = checkpoint.get("version", None)
        if version is not None:
            logger.info('==> Checkpoint trained from version: %s' % version)

        logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')
        model_state = self.state_dict()
        model_state_disk_filter = {}
        for key, val in model_state_disk.items():
            if key in model_state and model_state_disk[key].shape == model_state[key].shape:
                model_state_disk_filter[key] = val
            else:
                if key not in model_state:
                    print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')
                else:
                    print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}')

        model_state_disk = model_state_disk_filter

        missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)

        logger.info(f'Missing keys: {missing_keys}')
        logger.info(f'The number of missing keys: {len(missing_keys)}')
        logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')
        logger.info('==> Done (total keys %d)' % (len(model_state)))

        epoch = checkpoint.get('epoch', -1)
        it = checkpoint.get('it', 0.0)

        return it, epoch
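The filtering loop in `load_params_from_file` keeps only checkpoint entries that exist in the model with a matching shape. A hedged sketch of that loop, using minimal stand-ins for tensors (`_FakeTensor` is hypothetical; any object with a `.shape` attribute works):

```python
def filter_state_dict(disk_state, model_state):
    """Drop checkpoint entries absent from the model or with mismatched shapes.

    Mirrors the key/shape filter in load_params_from_file, written as a
    standalone function over .shape-bearing values.
    """
    kept, dropped = {}, []
    for key, val in disk_state.items():
        if key in model_state and val.shape == model_state[key].shape:
            kept[key] = val
        else:
            dropped.append(key)
    return kept, dropped


class _FakeTensor:
    """Minimal tensor stand-in: only the .shape attribute is needed here."""
    def __init__(self, shape):
        self.shape = shape


kept, dropped = filter_state_dict(
    {"encoder.w": _FakeTensor((8, 4)),   # present, shape matches -> kept
     "head.w": _FakeTensor((2, 4)),      # present, shape differs -> dropped
     "extra.b": _FakeTensor((3,))},      # absent from model -> dropped
    {"encoder.w": _FakeTensor((8, 4)), "head.w": _FakeTensor((6, 4))},
)
```

Loading the surviving subset with `strict=False`, as the method does, lets partially compatible checkpoints initialize whatever they can.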

    def match_token_map(self, data):
        traj_pos = data['map_save']['traj_pos'].to(torch.float)
        traj_theta = data['map_save']['traj_theta'].to(torch.float)
        pl_idx_list = data['map_save']['pl_idx_list']
        token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)
        token_src = self.map_token['traj_src'].to(traj_pos.device)
        max_traj_len = self.map_token['traj_src'].shape[1]
        pl_num = traj_pos.shape[0]

        pt_token_pos = traj_pos[:, 0, :].clone()
        pt_token_orientation = traj_theta.clone()
        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[..., 0, 0] = cos
        rot_mat[..., 0, 1] = -sin
        rot_mat[..., 1, 0] = sin
        rot_mat[..., 1, 1] = cos
        traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
        distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))
        pt_token_id = torch.argmin(distance, dim=1)

        if self.noise:
            topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
            sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
            pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)

        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[..., 0, 0] = cos
        rot_mat[..., 0, 1] = sin
        rot_mat[..., 1, 0] = -sin
        rot_mat[..., 1, 1] = cos
        token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
                                    rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
        token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)

        pl_idx_full = pl_idx_list.clone()
        token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
        count_nums = []
        for pl in pl_idx_full.unique():
            pt = token2pl[0, token2pl[1, :] == pl]
            left_side = (data['pt_token']['side'][pt] == 0).sum()
            right_side = (data['pt_token']['side'][pt] == 1).sum()
            center_side = (data['pt_token']['side'][pt] == 2).sum()
            count_nums.append(torch.Tensor([left_side, right_side, center_side]))
        count_nums = torch.stack(count_nums, dim=0)
        num_polyline = int(count_nums.max().item())
        traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=torch.bool)
        idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
        idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)
        counts_num_expanded = count_nums.unsqueeze(-1)
        mask_update = idx_matrix < counts_num_expanded
        traj_mask[mask_update] = True

        data['pt_token']['traj_mask'] = traj_mask
        data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                            device=traj_pos.device, dtype=torch.float)], dim=-1)
        data['pt_token']['orientation'] = pt_token_orientation
        data['pt_token']['height'] = data['pt_token']['position'][:, -1]
        data[('pt_token', 'to', 'map_polygon')] = {}
        data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
        data['pt_token']['token_idx'] = pt_token_id
        return data
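Before the nearest-token lookup, `match_token_map` rotates each polyline into a local frame anchored at its first point via `torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat)`. A pure-Python sketch of that world-to-local step for a single polyline (`to_local_frame` is a hypothetical helper using the same row-vector convention as the code above):

```python
import math


def to_local_frame(points, origin, theta):
    """Map world-frame points into the frame anchored at origin with heading theta.

    Row-vector convention matching match_token_map:
    [dx, dy] @ [[cos, -sin], [sin, cos]].
    """
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    local = []
    for x, y in points:
        dx, dy = x - origin[0], y - origin[1]
        local.append((dx * cos_t + dy * sin_t, -dx * sin_t + dy * cos_t))
    return local


# A polyline heading along +y with theta = pi/2: in the local frame its
# first point sits at the origin and the next point lies on the +x axis.
local = to_local_frame([(1.0, 1.0), (1.0, 2.0)], origin=(1.0, 1.0), theta=math.pi / 2)
```

With all candidate tokens expressed in this shared local frame, the squared-distance `argmin` (or, when `self.noise` is set, a random draw from the top-8 nearest) picks the map token id.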

    def sample_pt_pred(self, data):
        traj_mask = data['pt_token']['traj_mask']
        raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
        num_sample = (traj_mask.shape[2] - 1) // 3
        sample_idx = torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0] * traj_mask.shape[1] * num_sample]
        masked_pt_index = raw_pt_index.view(-1)[sample_idx.reshape(traj_mask.shape[0], traj_mask.shape[1], num_sample)]
        masked_pt_index = torch.sort(masked_pt_index, -1)[0]
        pt_valid_mask = traj_mask.clone()
        pt_valid_mask.scatter_(2, masked_pt_index, False)
        pt_pred_mask = traj_mask.clone()
        pt_pred_mask.scatter_(2, masked_pt_index, False)
        tmp_mask = pt_pred_mask.clone()
        tmp_mask[:, :, :] = True
        tmp_mask.scatter_(2, masked_pt_index-1, False)
        pt_pred_mask.masked_fill_(tmp_mask, False)
        pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
        pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)

        data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
        data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
        data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]

        return data


================================================
FILE: smart/modules/__init__.py
================================================
from smart.modules.smart_decoder import SMARTDecoder
from smart.modules.map_decoder import SMARTMapDecoder
from smart.modules.agent_decoder import SMARTAgentDecoder


================================================
FILE: smart/modules/agent_decoder.py
================================================
import pickle
from typing import Dict, Mapping, Optional
import torch
import torch.nn as nn
from smart.layers import MLPLayer
from smart.layers.attention_layer import AttentionLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from torch_cluster import radius, radius_graph
from torch_geometric.data import Batch, HeteroData
from torch_geometric.utils import dense_to_sparse, subgraph
from smart.utils import angle_between_2d_vectors, weight_init, wrap_angle
import math


def cal_polygon_contour(x, y, theta, width, length):
    left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
    left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
    left_front = (left_front_x, left_front_y)

    right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
    right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
    right_front = (right_front_x, right_front_y)

    right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
    right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
    right_back = (right_back_x, right_back_y)

    left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
    left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
    left_back = (left_back_x, left_back_y)
    polygon_contour = [left_front, right_front, right_back, left_back]

    return polygon_contour


class SMARTAgentDecoder(nn.Module):

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 time_span: Optional[int],
                 pl2a_radius: float,
                 a2a_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 token_data: Dict,
                 token_size=512) -> None:
        super(SMARTAgentDecoder, self).__init__()
        self.dataset = dataset
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.time_span = time_span if time_span is not None else num_historical_steps
        self.pl2a_radius = pl2a_radius
        self.a2a_radius = a2a_radius
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout

        input_dim_x_a = 2
        input_dim_r_t = 4
        input_dim_r_pt2a = 3
        input_dim_r_a2a = 3
        input_dim_token = 8

        self.type_a_emb = nn.Embedding(4, hidden_dim)
        self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)

        self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
        self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
        self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
                                           num_freq_bands=num_freq_bands)
        self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
                                          num_freq_bands=num_freq_bands)
        self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim)

        self.t_attn_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.pt2a_attn_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.a2a_attn_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.token_size = token_size
        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=self.token_size)
        self.trajectory_token = token_data['token']
        self.trajectory_token_traj = token_data['traj']
        self.trajectory_token_all = token_data['token_all']
        self.apply(weight_init)
        self.shift = 5
        self.beam_size = 5
        self.hist_mask = True

    def transform_rel(self, token_traj, prev_pos, prev_heading=None):
        if prev_heading is None:
            diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
            prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])

        num_agent, num_step, traj_num, traj_dim = token_traj.shape
        cos, sin = prev_heading.cos(), prev_heading.sin()
        rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
        rot_mat[:, :, 0, 0] = cos
        rot_mat[:, :, 0, 1] = -sin
        rot_mat[:, :, 1, 0] = sin
        rot_mat[:, :, 1, 1] = cos
        agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
        agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
        return agent_pred_rel
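`transform_rel` rotates relative token displacements by the previous heading and anchors them at the last observed position. The sketch below shows the per-point rotate-and-anchor arithmetic for one agent (`place_template` is a hypothetical helper; it reproduces the row-vector matrix convention used in the code above, `p @ [[cos, -sin], [sin, cos]]`):

```python
import math


def place_template(template, last_pos, heading):
    """Rotate a relative trajectory template and anchor it at last_pos.

    Same row-vector convention as transform_rel's rot_mat, applied point-wise.
    """
    cos_t, sin_t = math.cos(heading), math.sin(heading)
    return [(dx * cos_t + dy * sin_t + last_pos[0],
             -dx * sin_t + dy * cos_t + last_pos[1])
            for dx, dy in template]


# With heading = 0 the template is simply translated to last_pos.
placed = place_template([(1.0, 0.0), (2.0, 0.0)], last_pos=(5.0, 5.0), heading=0.0)
```

In the batched version, the same rotation is applied to every candidate token trajectory at once via `torch.bmm` before adding `prev_pos[:, :, -1:, :]`.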

    def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False):
        num_agent, num_step, traj_dim = pos_a.shape
        motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
                                     pos_a[:, 1:] - pos_a[:, :-1]], dim=1)

        agent_type = data['agent']['type']
        veh_mask = (agent_type == 0)
        cyc_mask = (agent_type == 2)
        ped_mask = (agent_type == 1)
        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1))
        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))

        if inference:
            agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
            trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(
                torch.float)
            trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(
                torch.float)
            trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(
                torch.float)
            agent_token_traj_all[veh_mask] = torch.cat(
                [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
            agent_token_traj_all[ped_mask] = torch.cat(
                [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
            agent_token_traj_all[cyc_mask] = torch.cat(
                [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)

        agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
        agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
        agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
        agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]

        agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device)
        agent_token_traj[veh_mask] = trajectory_token_veh
        agent_token_traj[ped_mask] = trajectory_token_ped
        agent_token_traj[cyc_mask] = trajectory_token_cyc

        vel = data['agent']['token_velocity']

        categorical_embs = [
            self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step, dim=0),
            self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave(
                repeats=num_step, dim=0)
        ]
        feature_a = torch.stack(
            [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
             ], dim=-1)

        x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
                           categorical_embs=categorical_embs)
        x_a = x_a.view(-1, num_step, self.hidden_dim)

        feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
        feat_a = self.fusion_emb(feat_a)

        if inference:
            return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs
        else:
            return feat_a, agent_token_traj

    def agent_predict_next(self, data, agent_category, feat_a):
        num_agent, num_step, traj_dim = data['agent']['token_pos'].shape
        agent_type = data['agent']['type']
        veh_mask = (agent_type == 0)  # * agent_category==3
        cyc_mask = (agent_type == 2)  # * agent_category==3
        ped_mask = (agent_type == 1)  # * agent_category==3
        token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device)
        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
        token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])
        token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])
        return token_res

    def agent_predict_next_inf(self, data, agent_category, feat_a):
        num_agent, traj_dim = feat_a.shape
        agent_type = data['agent']['type']

        veh_mask = (agent_type == 0)  # * agent_category==3
        cyc_mask = (agent_type == 2)  # * agent_category==3
        ped_mask = (agent_type == 1)  # * agent_category==3

        token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device)
        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
        token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])
        token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])

        return token_res

    def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None):
        pos_t = pos_a.reshape(-1, self.input_dim)
        head_t = head_a.reshape(-1)
        head_vector_t = head_vector_a.reshape(-1, 2)
        hist_mask = mask.clone()

        if self.hist_mask and self.training:
            hist_mask[
                torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
        elif inference_mask is not None:
            mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
        else:
            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)

        edge_index_t = dense_to_sparse(mask_t)[0]
        edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]]
        edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift]
        rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
        rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
        r_t = torch.stack(
            [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
             rel_head_t,
             edge_index_t[0] - edge_index_t[1]], dim=-1)
        r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
        return edge_index_t, r_t

    def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s):
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
        edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
                                      max_num_neighbors=300)
        edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]
        rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
        rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
        r_a2a = torch.stack(
            [torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
             rel_head_a2a], dim=-1)
        r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
        return edge_index_a2a, r_a2a

    def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a, mask,
                             batch_s, batch_pl):
        mask_pl2a = mask.clone()
        mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
        pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pl = data['pt_token']['orientation'].contiguous()
        pos_pl = pos_pl.repeat(num_step, 1)
        orient_pl = orient_pl.repeat(num_step)
        edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
                                 batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
        edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]
        rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
        rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
        r_pl2a = torch.stack(
            [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
             rel_orient_pl2a], dim=-1)
        r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
        return edge_index_pl2a, r_pl2a

    def forward(self,
                data: HeteroData,
                map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        pos_a = data['agent']['token_pos']
        head_a = data['agent']['token_heading']
        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
        num_agent, num_step, traj_dim = pos_a.shape
        agent_category = data['agent']['category']
        agent_token_index = data['agent']['token_idx']
        feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index,
                                                              pos_a, head_vector_a)

        agent_valid_mask = data['agent']['agent_valid_mask'].clone()
        # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
        # agent_valid_mask[~eval_mask] = False
        mask = agent_valid_mask
        edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask)

        if isinstance(data, Batch):
            batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
                                 for t in range(num_step)], dim=0)
            batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
                                  for t in range(num_step)], dim=0)
        else:
            batch_s = torch.arange(num_step,
                                   device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
            batch_pl = torch.arange(num_step,
                                    device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])

        mask_s = mask.transpose(0, 1).reshape(-1)
        edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s)
        mask[agent_category != 3] = False
        edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
                                                            head_vector_a, mask, batch_s, batch_pl)

        for i in range(self.num_layers):
            feat_a = feat_a.reshape(-1, self.hidden_dim)
            feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
            feat_a = feat_a.reshape(-1, num_step,
                                    self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
            feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
                repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                    -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
            feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
            feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)

        num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape
        next_token_prob = self.token_predict_head(feat_a)
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)

        next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
        next_token_eval_mask = mask.clone()
        next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
        next_token_eval_mask[:, -1] = False

        return {'x_a': feat_a,
                'next_token_idx': next_token_idx,
                'next_token_prob': next_token_prob,
                'next_token_idx_gt': next_token_index_gt,
                'next_token_eval_mask': next_token_eval_mask,
                }

    def inference(self,
                  data: HeteroData,
                  map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
        pos_a = data['agent']['token_pos'].clone()
        head_a = data['agent']['token_heading'].clone()
        num_agent, num_step, traj_dim = pos_a.shape
        pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)

        agent_valid_mask = data['agent']['agent_valid_mask'].clone()
        agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
        agent_valid_mask[~eval_mask] = False
        agent_token_index = data['agent']['token_idx']
        agent_category = data['agent']['category']
        feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(
            data,
            agent_category,
            agent_token_index,
            pos_a,
            head_vector_a,
            inference=True)

        agent_type = data["agent"]["type"]
        veh_mask = (agent_type == 0)  # * agent_category==3
        cyc_mask = (agent_type == 2)  # * agent_category==3
        ped_mask = (agent_type == 1)  # * agent_category==3
        av_mask = data["agent"]["av_index"]

        self.num_recurrent_steps_val = data["agent"]['position'].shape[1] - self.num_historical_steps
        pred_traj = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device)
        pred_head = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, device=feat_a.device)
        pred_prob = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val // self.shift, device=feat_a.device)
        next_token_idx_list = []
        mask = agent_valid_mask.clone()
        feat_a_t_dict = {}
        for t in range(self.num_recurrent_steps_val // self.shift):
            if t == 0:
                inference_mask = mask.clone()
                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
            else:
                inference_mask = torch.zeros_like(mask)
                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
            edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask, inference_mask)
            if isinstance(data, Batch):
                batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
                                     for t in range(num_step)], dim=0)
                batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
                                      for t in range(num_step)], dim=0)
            else:
                batch_s = torch.arange(num_step,
                                       device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
                batch_pl = torch.arange(num_step,
                                        device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
            # In the inference stage, only build edges for the current step, since decoding is recurrent
            edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
                                                                head_vector_a,
                                                                inference_mask, batch_s,
                                                                batch_pl)
            mask_s = inference_mask.transpose(0, 1).reshape(-1)
            edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a,
                                                                batch_s, mask_s)

            for i in range(self.num_layers):
                if i in feat_a_t_dict:
                    feat_a = feat_a_t_dict[i]
                feat_a = feat_a.reshape(-1, self.hidden_dim)
                feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
                feat_a = feat_a.reshape(-1, num_step,
                                        self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
                feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
                    repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                        -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
                feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
                feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)

                if i+1 not in feat_a_t_dict:
                    feat_a_t_dict[i+1] = feat_a
                else:
                    feat_a_t_dict[i+1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]

            next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])

            next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)

            topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1)

            expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)
            next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index)

            theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
                                       rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
                                           -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
            agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]

            sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device)
            agent_pred_rel = agent_pred_rel.gather(dim=1,
                                                   index=sample_index[..., None, None, None].expand(-1, -1, 6, 4,
                                                                                                    2))[:, 0, ...]
            pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0]
            pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
            diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
            pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])

            pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
            diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
            head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
            next_token_idx = next_token_idx.gather(dim=1, index=sample_index)
            next_token_idx = next_token_idx.squeeze(-1)
            next_token_idx_list.append(next_token_idx[:, None])
            agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
                next_token_idx[veh_mask]]
            agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
                next_token_idx[ped_mask]]
            agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
                next_token_idx[cyc_mask]]
            motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
                                         pos_a[:, 1:] - pos_a[:, :-1]], dim=1)

            head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)

            vel = motion_vector_a.clone() / (0.1 * self.shift)
            vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
            motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
            x_a = torch.stack(
                [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)

            x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
                               categorical_embs=categorical_embs)
            x_a = x_a.view(-1, num_step, self.hidden_dim)

            feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
            feat_a = self.fusion_emb(feat_a)

        agent_valid_mask[agent_category != 3] = False

        return {
            'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
            'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
            'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(),
            'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
            'pred_traj': pred_traj,
            'pred_head': pred_head,
            'next_token_idx': torch.cat(next_token_idx_list, dim=-1),
            'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1),
            'next_token_eval_mask': data['agent']['agent_valid_mask'],
            'pred_prob': pred_prob,
            'vel': vel
        }


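The inference loop above rotates each sampled token trajectory from the agent's local frame into the global frame via a per-agent rotation matrix (the `rot_mat` block). A minimal sketch of that transform, using a hypothetical `local_to_global` helper; note that the decoder multiplies row-vector points on the right (`p @ R`), which with this matrix layout rotates by +theta:

```python
import torch

def local_to_global(traj_local: torch.Tensor, theta: torch.Tensor, origin: torch.Tensor) -> torch.Tensor:
    # traj_local: [T, 4, 2] box-corner displacements in the agent frame,
    # theta: 0-dim heading, origin: [2] current agent position.
    # Mirrors the rot_mat construction in SMARTAgentDecoder.inference:
    # row vectors multiplied on the right, p @ R, rotate by +theta.
    cos, sin = theta.cos(), theta.sin()
    rot_mat = torch.stack([torch.stack([cos, sin]),
                           torch.stack([-sin, cos])])
    return traj_local @ rot_mat + origin

# A +x displacement for an agent heading +y should land on the +y axis.
traj = torch.tensor([[[1.0, 0.0]] * 4])   # one step, four identical corners
theta = torch.tensor(torch.pi / 2)
out = local_to_global(traj, theta, torch.zeros(2))
```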
================================================
FILE: smart/modules/map_decoder.py
================================================
import os.path
from typing import Dict
import torch
import torch.nn as nn
from torch_cluster import radius_graph
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from torch_geometric.utils import dense_to_sparse, subgraph
from smart.utils.nan_checker import check_nan_inf
from smart.layers.attention_layer import AttentionLayer
from smart.layers import MLPLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from smart.utils import angle_between_2d_vectors
from smart.utils import merge_edges
from smart.utils import weight_init
from smart.utils import wrap_angle
import pickle


class SMARTMapDecoder(nn.Module):

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token) -> None:
        super(SMARTMapDecoder, self).__init__()
        self.dataset = dataset
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.pl2pl_radius = pl2pl_radius
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout

        if input_dim == 2:
            input_dim_r_pt2pt = 3
        elif input_dim == 3:
            input_dim_r_pt2pt = 4
        else:
            raise ValueError('{} is not a valid dimension'.format(input_dim))

        self.type_pt_emb = nn.Embedding(17, hidden_dim)
        self.side_pt_emb = nn.Embedding(4, hidden_dim)
        self.polygon_type_emb = nn.Embedding(4, hidden_dim)
        self.light_pl_emb = nn.Embedding(4, hidden_dim)

        self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
                                            num_freq_bands=num_freq_bands)
        self.pt2pt_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.token_size = 1024
        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=self.token_size)
        input_dim_token = 22
        self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.map_token = map_token
        self.apply(weight_init)
        self.mask_pt = False

    def maybe_autocast(self, dtype=torch.float32):
        return torch.cuda.amp.autocast(dtype=dtype)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        pt_valid_mask = data['pt_token']['pt_valid_mask']
        pt_pred_mask = data['pt_token']['pt_pred_mask']
        pt_target_mask = data['pt_token']['pt_target_mask']
        mask_s = pt_valid_mask

        pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pt = data['pt_token']['orientation'].contiguous()
        orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
        token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
        pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
        pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]

        if self.input_dim in (2, 3):
            x_pt = pt_token_emb
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))

        token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
        token_light_type = data['map_polygon']['light_type'][token2pl[1]]
        x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
                                 self.polygon_type_emb(data['pt_token']['pl_type'].long()),
                                 self.light_pl_emb(token_light_type.long()),]
        x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
        edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
                                        batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
                                        loop=False, max_num_neighbors=100)
        if self.mask_pt:
            edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
        rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
        rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
        if self.input_dim == 2:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_orient_pt2pt], dim=-1)
        elif self.input_dim == 3:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_pos_pt2pt[:, -1],
                 rel_orient_pt2pt], dim=-1)
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))
        r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
        for i in range(self.num_layers):
            x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)

        next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
        next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]

        return {
            'x_pt': x_pt,
            'map_next_token_idx': next_token_idx,
            'map_next_token_prob': next_token_prob,
            'map_next_token_idx_gt': next_token_index_gt,
            'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
        }

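All of the edge builders in this repo (pt2pt above, and the temporal/a2a/pl2a builders in agent_decoder.py) encode the same relative geometry per edge: distance, bearing of the relative position in the receiver's heading frame, and relative heading. A standalone sketch, with minimal reimplementations of `wrap_angle` and `angle_between_2d_vectors` assumed to match the helpers in smart/utils:

```python
import torch

def wrap_angle(angle: torch.Tensor) -> torch.Tensor:
    # Wrap angles into (-pi, pi]; assumed behavior of the smart/utils helper.
    return (angle + torch.pi) % (2 * torch.pi) - torch.pi

def angle_between_2d_vectors(ctr_vector: torch.Tensor, nbr_vector: torch.Tensor) -> torch.Tensor:
    # Signed angle of nbr_vector in the frame of ctr_vector via atan2(cross, dot).
    return torch.atan2(
        ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0],
        (ctr_vector * nbr_vector).sum(dim=-1))

def relative_edge_features(pos: torch.Tensor, head: torch.Tensor,
                           edge_index: torch.Tensor) -> torch.Tensor:
    # pos: [N, 2], head: [N], edge_index: [2, E] (source -> target).
    src, dst = edge_index[0], edge_index[1]
    rel_pos = pos[src] - pos[dst]
    head_vec = torch.stack([head.cos(), head.sin()], dim=-1)
    return torch.stack(
        [rel_pos.norm(p=2, dim=-1),                         # distance
         angle_between_2d_vectors(head_vec[dst], rel_pos),  # bearing in target frame
         wrap_angle(head[src] - head[dst])], dim=-1)        # relative heading

pos = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
head = torch.tensor([0.0, torch.pi / 2])
edge_index = torch.tensor([[0], [1]])   # one edge: node 0 -> node 1
feats = relative_edge_features(pos, head, edge_index)   # shape [1, 3]
```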

================================================
FILE: smart/modules/smart_decoder.py
================================================
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch_geometric.data import HeteroData
from smart.modules.agent_decoder import SMARTAgentDecoder
from smart.modules.map_decoder import SMARTMapDecoder


class SMARTDecoder(nn.Module):

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 time_span: Optional[int],
                 pl2a_radius: float,
                 a2a_radius: float,
                 num_freq_bands: int,
                 num_map_layers: int,
                 num_agent_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token: Dict,
                 token_data: Dict,
                 use_intention=False,
                 token_size=512) -> None:
        super(SMARTDecoder, self).__init__()
        self.map_encoder = SMARTMapDecoder(
            dataset=dataset,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_historical_steps=num_historical_steps,
            pl2pl_radius=pl2pl_radius,
            num_freq_bands=num_freq_bands,
            num_layers=num_map_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            map_token=map_token
        )
        self.agent_encoder = SMARTAgentDecoder(
            dataset=dataset,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_historical_steps=num_historical_steps,
            time_span=time_span,
            pl2a_radius=pl2a_radius,
            a2a_radius=a2a_radius,
            num_freq_bands=num_freq_bands,
            num_layers=num_agent_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            dropout=dropout,
            token_size=token_size,
            token_data=token_data
        )
        self.map_enc = None

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        map_enc = self.map_encoder(data)
        agent_enc = self.agent_encoder(data, map_enc)
        return {**map_enc, **agent_enc}

    def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        map_enc = self.map_encoder(data)
        agent_enc = self.agent_encoder.inference(data, map_enc)
        return {**map_enc, **agent_enc}

    def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
        agent_enc = self.agent_encoder.inference(data, map_enc)
        return {**map_enc, **agent_enc}

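`SMARTDecoder.inference` delegates to `SMARTAgentDecoder.inference`, whose rollout step samples the next motion token by keeping the `beam_size` most probable tokens and drawing one in proportion to its probability. A minimal sketch of that sampling step (the function name and shapes here are illustrative):

```python
import torch

def sample_next_token(logits: torch.Tensor, beam_size: int = 5):
    # Per-agent top-k sampling as in SMARTAgentDecoder.inference:
    # softmax over the token vocabulary, keep the beam_size most likely
    # tokens, then draw one of them proportionally to its probability.
    probs = torch.softmax(logits, dim=-1)                   # [num_agent, vocab]
    topk_prob, topk_idx = torch.topk(probs, k=beam_size, dim=-1)
    sample = torch.multinomial(topk_prob, num_samples=1)    # [num_agent, 1]
    next_idx = topk_idx.gather(dim=-1, index=sample).squeeze(-1)
    next_prob = topk_prob.gather(dim=-1, index=sample).squeeze(-1)
    return next_idx, next_prob

logits = torch.tensor([[0.0, 0.0, 10.0, 0.0]])   # token 2 dominates
idx, prob = sample_next_token(logits, beam_size=2)
```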

================================================
FILE: smart/preprocess/__init__.py
================================================


================================================
FILE: smart/preprocess/preprocess.py
================================================
import numpy as np
import pandas as pd
import os
import torch
from typing import Any, Dict, List, Optional

predict_unseen_agents = False
vector_repr = True
_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
_point_sides = ['LEFT', 'RIGHT', 'CENTER']
_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
_polygon_is_intersections = [True, False, None]


Lane_type_hash = {
    4: "BIKE",
    3: "VEHICLE",
    2: "VEHICLE",
    1: "BUS"
}

boundary_type_hash = {
        5: "UNKNOWN",
        6: "DASHED_WHITE",
        7: "SOLID_WHITE",
        8: "DOUBLE_DASH_WHITE",
        9: "DASHED_YELLOW",
        10: "DOUBLE_DASH_YELLOW",
        11: "SOLID_YELLOW",
        12: "DOUBLE_SOLID_YELLOW",
        13: "DASH_SOLID_YELLOW",
        14: "UNKNOWN",
        15: "EDGE",
        16: "EDGE"
}


def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
    if not predict_unseen_agents:  # keep only agents observed at the last historical time step
        historical_df = df[df['timestep'] == num_historical_steps-1]
        agent_ids = list(historical_df['track_id'].unique())
        df = df[df['track_id'].isin(agent_ids)]
    else:
        agent_ids = list(df['track_id'].unique())

    num_agents = len(agent_ids)
    # initialization
    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    agent_id: List[Optional[str]] = [None] * num_agents
    agent_type = torch.zeros(num_agents, dtype=torch.uint8)
    agent_category = torch.zeros(num_agents, dtype=torch.uint8)
    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)

    for track_id, track_df in df.groupby('track_id'):
        agent_idx = agent_ids.index(track_id)
        agent_steps = track_df['timestep'].values

        valid_mask[agent_idx, agent_steps] = True
        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
        predict_mask[agent_idx, agent_steps] = True
        if vector_repr:  # a time step t is valid only when both t and t-1 are valid
            valid_mask[agent_idx, 1: num_historical_steps] = (
                valid_mask[agent_idx, :num_historical_steps - 1] &
                valid_mask[agent_idx, 1: num_historical_steps])
            valid_mask[agent_idx, 0] = False
        predict_mask[agent_idx, :num_historical_steps] = False
        if not current_valid_mask[agent_idx]:
            predict_mask[agent_idx, num_historical_steps:] = False

        agent_id[agent_idx] = track_id
        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
        agent_category[agent_idx] = track_df['object_category'].values[0]
        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
                                                                          track_df['position_y'].values,
                                                                          track_df['position_z'].values],
                                                                         axis=-1)).float()
        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
                                                                          track_df['velocity_y'].values],
                                                                         axis=-1)).float()
        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
                                                                       track_df['width'].values,
                                                                       track_df["height"].values],
                                                                      axis=-1)).float()
    av_idx = agent_id.index(av_id)

    return {
        'num_nodes': num_agents,
        'av_index': av_idx,
        'valid_mask': valid_mask,
        'predict_mask': predict_mask,
        'id': agent_id,
        'type': agent_type,
        'category': agent_category,
        'position': position,
        'heading': heading,
        'velocity': velocity,
        'shape': shape
    }
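The masking convention above is subtle: under `vector_repr`, a historical step `t` counts as valid only when both `t` and `t-1` were observed, and step 0 is always invalidated. A minimal NumPy sketch of that rule (dummy flags, not real data):

```python
import numpy as np

# NumPy analogue of the vector_repr masking rule: a step t in
# [1, num_historical_steps) stays valid only if both t and t-1 were
# observed, and step 0 is always invalidated.
num_historical_steps = 5
observed = np.array([True, True, False, True, True])  # raw per-step flags

valid = observed.copy()
valid[1:num_historical_steps] = (observed[:num_historical_steps - 1]
                                 & observed[1:num_historical_steps])
valid[0] = False

print(valid.tolist())  # [False, True, False, False, True]
```

Note how the single missing observation at step 2 also invalidates step 3, since the displacement vector into step 3 cannot be formed.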

================================================
FILE: smart/tokens/__init__.py
================================================


================================================
FILE: smart/transforms/__init__.py
================================================
from smart.transforms.target_builder import WaymoTargetBuilder


================================================
FILE: smart/transforms/target_builder.py
================================================

import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import BaseTransform
from smart.utils import wrap_angle
from smart.utils.log import Logging


def to_16(data):
    if isinstance(data, dict):
        for key, value in data.items():
            new_value = to_16(value)
            data[key] = new_value
    if isinstance(data, torch.Tensor):
        if data.dtype == torch.float32:
            data = data.to(torch.float16)
    return data
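`to_16` walks a nested dict and downcasts only `float32` tensors, leaving everything else (including non-tensor values) untouched. A NumPy analogue of the same recursion pattern, with a hypothetical name `to_16_np`:

```python
import numpy as np

# NumPy analogue of to_16: recurse through a nested dict and downcast
# every float32 array to float16; other values pass through unchanged.
def to_16_np(data):
    if isinstance(data, dict):
        for key, value in data.items():
            data[key] = to_16_np(value)
    if isinstance(data, np.ndarray) and data.dtype == np.float32:
        data = data.astype(np.float16)
    return data

nested = {"a": np.zeros(3, dtype=np.float32),
          "b": {"c": np.zeros(2, dtype=np.float32), "d": "label"}}
out = to_16_np(nested)
print(out["a"].dtype, out["b"]["c"].dtype)  # float16 float16
```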


def tofloat32(data):
    for name in data:
        value = data[name]
        if isinstance(value, dict):
            value = tofloat32(value)
        elif isinstance(value, torch.Tensor) and value.dtype == torch.float64:
            value = value.to(torch.float32)
        data[name] = value
    return data


class WaymoTargetBuilder(BaseTransform):

    def __init__(self,
                 num_historical_steps: int,
                 num_future_steps: int,
                 mode="train") -> None:
        self.num_historical_steps = num_historical_steps
        self.num_future_steps = num_future_steps
        self.mode = mode
        self.num_features = 3
        self.augment = False
        self.logger = Logging().log(level='DEBUG')

    def score_ego_agent(self, agent):
        av_index = agent['av_index']
        agent["category"][av_index] = 5
        return agent

    def clip(self, agent, max_num=32):
        av_index = agent["av_index"]
        valid = agent['valid_mask']
        ego_pos = agent["position"][av_index]
        obstacle_mask = agent['type'] == 3
        # keep only the closest max_num agents to the ego vehicle
        distance = torch.norm(agent["position"][:, self.num_historical_steps - 1, :2] -
                              ego_pos[self.num_historical_steps - 1, :2], dim=-1)
        distance[obstacle_mask] = 10e5  # push static obstacles to the back of the sort
        sort_idx = distance.sort()[1]
        mask = torch.zeros(valid.shape[0])
        mask[sort_idx[:max_num]] = 1
        mask = mask.to(torch.bool)
        mask[av_index] = True
        new_av_index = mask[:av_index].sum()
        agent["num_nodes"] = int(mask.sum())
        agent["av_index"] = int(new_av_index)
        excluded = ["num_nodes", "av_index", "ego"]
        for key, val in agent.items():
            if key in excluded:
                continue
            if key == "id":
                val = list(np.array(val)[mask])
                agent[key] = val
                continue
            if len(val.size()) > 1:
                agent[key] = val[mask, ...]
            else:
                agent[key] = val[mask]
        return agent
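The ego-index remapping in `clip` relies on a small invariant: after filtering rows with a boolean mask, the ego's new row index equals the number of kept rows that precede it. A quick NumPy sketch with made-up values:

```python
import numpy as np

# Sketch of the av_index remapping in clip(): after rows are filtered by
# a boolean mask, the ego's new index is the count of kept rows before it.
av_index = 4
mask = np.array([True, False, True, False, True, True])
mask[av_index] = True  # the ego row is always kept
new_av_index = int(mask[:av_index].sum())
print(new_av_index)  # 2
```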

    def score_nearby_vehicle(self, agent, max_num=10):
        av_index = agent['av_index']
        agent["category"] = torch.zeros_like(agent["category"])
        obstacle_mask = agent['type'] == 3
        pos = agent["position"][av_index, self.num_historical_steps, :2]
        distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
        distance[obstacle_mask] = 10e5
        sort_idx = distance.sort()[1]
        nearby_mask = torch.zeros(distance.shape[0])
        nearby_mask[sort_idx[1:max_num]] = 1
        nearby_mask = nearby_mask.bool()
        agent["category"][nearby_mask] = 3
        agent["category"][obstacle_mask] = 0
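The selection above sorts agents by distance to the ego and marks the closest ones, skipping slot 0 of the sort order (the ego itself, at distance 0). A NumPy sketch with toy distances; note that `sort_idx[1:max_num]` selects `max_num - 1` agents, mirroring the indexing in the source:

```python
import numpy as np

# Nearest-k selection excluding self, as in score_nearby_vehicle.
max_num = 3
distance = np.array([0.0, 12.0, 3.0, 55.0, 7.0])  # index 0 is the ego
sort_idx = np.argsort(distance)
nearby_mask = np.zeros(len(distance), dtype=bool)
nearby_mask[sort_idx[1:max_num]] = True  # skips the ego; picks max_num - 1 agents
print(nearby_mask.tolist())  # [False, False, True, False, True]
```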

    def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
        av_index = agent['av_index']
        agent["category"] = torch.zeros_like(agent["category"])
        pos = agent["position"][av_index, self.num_historical_steps, :2]
        distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
        distance_all_time = torch.norm(agent["position"][:, :, :2] - agent["position"][av_index, :, :2], dim=-1)
        in_range_mask = distance_all_time < 150  # perception beyond 150 meters is considered unreliable
        agent["valid_mask"] = agent["valid_mask"] * in_range_mask
        # do not predict vehicles too far away from the ego car
        closest_vehicle = distance < 100
        valid = agent['valid_mask']
        valid_current = valid[:, self.num_historical_steps:]
        valid_counts = valid_current.sum(1)
        counts_vehicle = valid_counts >= 1
        no_background = agent['type'] != 3
        vehicle2pred = closest_vehicle & counts_vehicle & no_background
        if vehicle2pred.sum() > max_num:
            # too many candidates (many of them stationary): randomly keep max_num
            # so that moving vehicles are used for training as much as possible
            true_indices = torch.nonzero(vehicle2pred).squeeze(1)
            selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]
            vehicle2pred.fill_(False)
            vehicle2pred[selected_indices] = True
        agent["category"][vehicle2pred] = 3
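The subsampling step above caps the number of training targets: when more than `max_num` agents qualify, it keeps a uniformly random subset of exactly `max_num`. A NumPy sketch of the same pattern:

```python
import numpy as np

# Sketch of the candidate subsampling in score_trained_vehicle: keep a
# random subset of exactly max_num True entries in a boolean mask.
rng = np.random.default_rng(0)
max_num = 2
candidates = np.array([True, False, True, True, False, True])

if candidates.sum() > max_num:
    true_indices = np.flatnonzero(candidates)
    selected = rng.permutation(true_indices)[:max_num]
    candidates[:] = False
    candidates[selected] = True

print(int(candidates.sum()))  # 2
```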

    def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):
        origin = position[:, num_historical_steps - 1]
        theta = heading[:, num_historical_steps - 1]
        cos, sin = theta.cos(), theta.sin()
        rot_mat = theta.new_zeros(num_nodes, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = -sin
        rot_mat[:, 1, 0] = sin
        rot_mat[:, 1, 1] = cos
        target = origin.new_zeros(num_nodes, num_future_steps, 4)
        target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -
                                    origin[:, :2].unsqueeze(1), rot_mat)
        his = origin.new_zeros(num_nodes, num_historical_steps, 4)
        his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -
                                 origin[:, :2].unsqueeze(1), rot_mat)
        if position.size(2) == 3:
            target[..., 2] = (position[:, num_historical_steps:, 2] -
                              origin[:, 2].unsqueeze(-1))
            his[..., 2] = (position[:, :num_historical_steps, 2] -
                           origin[:, 2].unsqueeze(-1))
            target[..., 3] = wrap_angle(heading[:, num_historical_steps:] -
                                        theta.unsqueeze(-1))
            his[..., 3] = wrap_angle(heading[:, :num_historical_steps] -
                                     theta.unsqueeze(-1))
        else:
            target[..., 2] = wrap_angle(heading[:, num_historical_steps:] -
                                        theta.unsqueeze(-1))
            his[..., 2] = wrap_angle(heading[:, :num_historical_steps] -
                                     theta.unsqueeze(-1))
        return his, target
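The core transform in `rotate_agents` is the right-multiplication by `[[cos, -sin], [sin, cos]]`, which takes translated world coordinates into each agent's frame: a point directly ahead of the agent lands on the +x axis. A NumPy sketch of that rotation for a single agent:

```python
import numpy as np

# Sketch of the agent-centric rotation in rotate_agents: translate by the
# agent's origin, then right-multiply by the rotation matrix. A point one
# meter ahead of the agent along its heading maps to (1, 0).
theta = np.pi / 3  # agent heading in the world frame
origin = np.array([5.0, -2.0])
rot_mat = np.array([[np.cos(theta), -np.sin(theta)],
                    [np.sin(theta),  np.cos(theta)]])

point = origin + np.array([np.cos(theta), np.sin(theta)])
local = (point - origin) @ rot_mat
print(np.round(local, 6).tolist())  # [1.0, 0.0]
```

Right-multiplying a row vector by this matrix applies the inverse rotation, which is exactly what the batched `torch.bmm(position - origin, rot_mat)` calls above compute per agent.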

    def __call__(self, data) -> HeteroData:
        agent = data["agent"]
        self.score_ego_agent(agent)
        self.score_trained_vehicle(agent, max_num=32)
        return HeteroData(data)


================================================
FILE: smart/utils/__init__.py
================================================

from smart.utils.geometry import angle_between_2d_vectors
from smart.utils.geometry import angle_between_3d_vectors
from smart.utils.geometry import side_to_directed_lineseg
from smart.utils.geometry import wrap_angle
from smart.utils.graph import add_edges
from smart.utils.graph import bipartite_dense_to_sparse
from smart.utils.graph import complete_graph
from smart.utils.graph import merge_edges
from smart.utils.graph import unbatch
from smart.utils.list import safe_list_index
from smart.utils.weight_init import weight_init


================================================
FILE: smart/utils/cluster_reader.py
================================================
import io
import pickle
import pandas as pd
import json


class LoadScenarioFromCeph:
    def __init__(self):
        from petrel_client.client import Client
        self.file_client = Client('~/petreloss.conf')

    def list(self, dir_path):
        return list(self.file_client.list(dir_path))

    def save(self, data, url):
        self.file_client.put(url, pickle.dumps(data))

    def read_correct_csv(self, scenario_path):
        output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python")
        return output

    def contains(self, url):
        return self.file_client.contains(url)

    def read_string(self, csv_url):
        df = pd.read_csv(io.StringIO(str(self.file_client.get(csv_url), 'utf-8')), sep=r'\s+', low_memory=False)
        return df

    def read(self, scenario_path):
        with io.BytesIO(self.file_client.get(scenario_path)) as f:
            datas = pickle.load(f)
            return datas

================================================
SYMBOL INDEX (167 symbols across 27 files)
================================================
FILE: data_preprocess.py
  function safe_list_index (line 65) | def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
  function get_agent_features (line 72) | def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10,...
  function get_map_features (line 145) | def get_map_features(map_infos, tf_current_light, dim=3):
  function process_agent (line 335) | def process_agent(track_info, tracks_to_predict, sdc_track_index, scenar...
  function process_dynamic_map (line 382) | def process_dynamic_map(dynamic_map_infos):
  function decode_tracks_from_proto (line 466) | def decode_tracks_from_proto(tracks):
  function decode_map_features_from_proto (line 488) | def decode_map_features_from_proto(map_features):
  function decode_dynamic_map_states_from_proto (line 610) | def decode_dynamic_map_states_from_proto(dynamic_map_states):
  function process_single_data (line 630) | def process_single_data(scenario):
  function wm2argo (line 663) | def wm2argo(file, dir_name, output_dir):
  function batch_process9s_transformer (line 693) | def batch_process9s_transformer(dir_name, output_dir, num_workers=2):

FILE: scripts/traj_clstering.py
  function average_distance_vectorized (line 5) | def average_distance_vectorized(point_set1, centroids):
  function assign_clusters (line 10) | def assign_clusters(sub_X, centroids):
  function Kdisk_cluster (line 15) | def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
  function cal_polygon_contour (line 58) | def cal_polygon_contour(x, y, theta, width, length):

FILE: smart/datamodules/scalable_datamodule.py
  class MultiDataModule (line 9) | class MultiDataModule(pl.LightningDataModule):
    method __init__ (line 18) | def __init__(self,
    method setup (line 69) | def setup(self, stage: Optional[str] = None) -> None:
    method train_dataloader (line 77) | def train_dataloader(self):
    method val_dataloader (line 82) | def val_dataloader(self):
    method test_dataloader (line 87) | def test_dataloader(self):

FILE: smart/datasets/preprocess.py
  function cal_polygon_contour (line 10) | def cal_polygon_contour(x, y, theta, width, length):
  function interplating_polyline (line 33) | def interplating_polyline(polylines, heading, distance=0.5, split_distac...
  function average_distance_vectorized (line 117) | def average_distance_vectorized(point_set1, centroids):
  function assign_clusters (line 122) | def assign_clusters(sub_X, centroids):
  class TokenProcessor (line 127) | class TokenProcessor:
    method __init__ (line 129) | def __init__(self, token_size):
    method preprocess (line 140) | def preprocess(self, data):
    method get_trajectory_token (line 150) | def get_trajectory_token(self):
    method clean_heading (line 171) | def clean_heading(self, data):
    method tokenize_agent (line 195) | def tokenize_agent(self, data):
    method match_token (line 295) | def match_token(self, pos, valid_mask, heading, category, agent_catego...
    method tokenize_map (line 403) | def tokenize_map(self, data):

FILE: smart/datasets/scalable_dataset.py
  function distance (line 11) | def distance(point1, point2):
  class MultiDataset (line 15) | class MultiDataset(Dataset):
    method __init__ (line 16) | def __init__(self,
    method raw_dir (line 66) | def raw_dir(self) -> str:
    method raw_paths (line 70) | def raw_paths(self) -> List[str]:
    method raw_file_names (line 74) | def raw_file_names(self) -> Union[str, List[str], Tuple]:
    method processed_file_names (line 78) | def processed_file_names(self) -> Union[str, List[str], Tuple]:
    method len (line 81) | def len(self) -> int:
    method generate_ref_token (line 84) | def generate_ref_token(self):
    method get (line 87) | def get(self, idx: int):

FILE: smart/layers/attention_layer.py
  class AttentionLayer (line 12) | class AttentionLayer(MessagePassing):
    method __init__ (line 14) | def __init__(self,
    method forward (line 57) | def forward(self,
    method message (line 74) | def message(self,
    method update (line 90) | def update(self,
    method _attn_block (line 97) | def _attn_block(self,
    method _ff_block (line 108) | def _ff_block(self, x: torch.Tensor) -> torch.Tensor:

FILE: smart/layers/fourier_embedding.py
  class FourierEmbedding (line 9) | class FourierEmbedding(nn.Module):
    method __init__ (line 11) | def __init__(self,
    method forward (line 35) | def forward(self,
  class MLPEmbedding (line 56) | class MLPEmbedding(nn.Module):
    method __init__ (line 57) | def __init__(self,
    method forward (line 73) | def forward(self,

FILE: smart/layers/mlp_layer.py
  class MLPLayer (line 8) | class MLPLayer(nn.Module):
    method __init__ (line 10) | def __init__(self,
    method forward (line 23) | def forward(self, x: torch.Tensor) -> torch.Tensor:

FILE: smart/metrics/average_meter.py
  class AverageMeter (line 6) | class AverageMeter(Metric):
    method __init__ (line 8) | def __init__(self, **kwargs) -> None:
    method update (line 13) | def update(self, val: torch.Tensor) -> None:
    method compute (line 17) | def compute(self) -> torch.Tensor:

FILE: smart/metrics/min_ade.py
  class minMultiADE (line 11) | class minMultiADE(Metric):
    method __init__ (line 13) | def __init__(self,
    method update (line 21) | def update(self,
    method compute (line 44) | def compute(self) -> torch.Tensor:
  class minADE (line 48) | class minADE(Metric):
    method __init__ (line 50) | def __init__(self,
    method update (line 59) | def update(self,
    method compute (line 84) | def compute(self) -> torch.Tensor:

FILE: smart/metrics/min_fde.py
  class minMultiFDE (line 10) | class minMultiFDE(Metric):
    method __init__ (line 12) | def __init__(self,
    method update (line 20) | def update(self,
    method compute (line 34) | def compute(self) -> torch.Tensor:
  class minFDE (line 38) | class minFDE(Metric):
    method __init__ (line 40) | def __init__(self,
    method update (line 49) | def update(self,
    method compute (line 60) | def compute(self) -> torch.Tensor:

FILE: smart/metrics/next_token_cls.py
  class TokenCls (line 10) | class TokenCls(Metric):
    method __init__ (line 12) | def __init__(self,
    method update (line 20) | def update(self,
    method compute (line 29) | def compute(self) -> torch.Tensor:

FILE: smart/metrics/utils.py
  function topk (line 8) | def topk(
  function topkind (line 44) | def topkind(
  function valid_filter (line 80) | def valid_filter(
  function new_batch_nms (line 108) | def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
  function batch_nms (line 163) | def batch_nms(pred_trajs, pred_scores,
  function batch_nms_token (line 224) | def batch_nms_token(pred_trajs, pred_scores,

FILE: smart/model/smart.py
  function cal_polygon_contour (line 20) | def cal_polygon_contour(x, y, theta, width, length):
  function joint_scene_from_states (line 41) | def joint_scene_from_states(states, object_ids) -> sim_agents_submission...
  class SMART (line 53) | class SMART(pl.LightningModule):
    method __init__ (line 55) | def __init__(self, model_config) -> None:
    method get_trajectory_token (line 106) | def get_trajectory_token(self):
    method init_map_token (line 113) | def init_map_token(self):
    method forward (line 124) | def forward(self, data: HeteroData):
    method inference (line 128) | def inference(self, data: HeteroData):
    method maybe_autocast (line 132) | def maybe_autocast(self, dtype=torch.float16):
    method training_step (line 140) | def training_step(self,
    method validation_step (line 157) | def validation_step(self,
    method on_validation_start (line 199) | def on_validation_start(self):
    method configure_optimizers (line 205) | def configure_optimizers(self):
    method load_params_from_file (line 218) | def load_params_from_file(self, filename, logger, to_cpu=False):
    method match_token_map (line 257) | def match_token_map(self, data):
    method sample_pt_pred (line 321) | def sample_pt_pred(self, data):

FILE: smart/modules/agent_decoder.py
  function cal_polygon_contour (line 15) | def cal_polygon_contour(x, y, theta, width, length):
  class SMARTAgentDecoder (line 36) | class SMARTAgentDecoder(nn.Module):
    method __init__ (line 38) | def __init__(self,
    method transform_rel (line 110) | def transform_rel(self, token_traj, prev_pos, prev_heading=None):
    method agent_token_embedding (line 126) | def agent_token_embedding(self, data, agent_category, agent_token_inde...
    method agent_predict_next (line 194) | def agent_predict_next(self, data, agent_category, feat_a):
    method agent_predict_next_inf (line 206) | def agent_predict_next_inf(self, data, agent_category, feat_a):
    method build_temporal_edge (line 221) | def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent,...
    method build_interaction_edge (line 249) | def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s...
    method build_map2agent_edge (line 265) | def build_map2agent_edge(self, data, num_step, agent_category, pos_a, ...
    method forward (line 288) | def forward(self,
    method inference (line 351) | def inference(self,

FILE: smart/modules/map_decoder.py
  class SMARTMapDecoder (line 20) | class SMARTMapDecoder(nn.Module):
    method __init__ (line 22) | def __init__(self,
    method maybe_autocast (line 73) | def maybe_autocast(self, dtype=torch.float32):
    method forward (line 76) | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:

FILE: smart/modules/smart_decoder.py
  class SMARTDecoder (line 9) | class SMARTDecoder(nn.Module):
    method __init__ (line 11) | def __init__(self,
    method forward (line 62) | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
    method inference (line 67) | def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
    method inference_no_map (line 72) | def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, tor...

FILE: smart/preprocess/preprocess.py
  function get_agent_features (line 44) | def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10,...

FILE: smart/transforms/target_builder.py
  function to_16 (line 10) | def to_16(data):
  function tofloat32 (line 21) | def tofloat32(data):
  class WaymoTargetBuilder (line 32) | class WaymoTargetBuilder(BaseTransform):
    method __init__ (line 34) | def __init__(self,
    method score_ego_agent (line 45) | def score_ego_agent(self, agent):
    method clip (line 50) | def clip(self, agent, max_num=32):
    method score_nearby_vehicle (line 79) | def score_nearby_vehicle(self, agent, max_num=10):
    method score_trained_vehicle (line 93) | def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
    method rotate_agents (line 117) | def rotate_agents(self, position, heading, num_nodes, num_historical_s...
    method __call__ (line 148) | def __call__(self, data) -> HeteroData:

FILE: smart/utils/cluster_reader.py
  class LoadScenarioFromCeph (line 7) | class LoadScenarioFromCeph:
    method __init__ (line 8) | def __init__(self):
    method list (line 12) | def list(self, dir_path):
    method save (line 15) | def save(self, data, url):
    method read_correct_csv (line 18) | def read_correct_csv(self, scenario_path):
    method contains (line 22) | def contains(self, url):
    method read_string (line 25) | def read_string(self, csv_url):
    method read (line 30) | def read(self, scenario_path):
    method read_json (line 35) | def read_json(self, path):
    method read_csv (line 40) | def read_csv(self, scenario_path):
    method read_model (line 43) | def read_model(self, model_path):

FILE: smart/utils/config.py
  function load_config_act (line 6) | def load_config_act(path):
  function load_config_init (line 13) | def load_config_init(path):

FILE: smart/utils/geometry.py
  function angle_between_2d_vectors (line 7) | def angle_between_2d_vectors(
  function angle_between_3d_vectors (line 14) | def angle_between_3d_vectors(
  function side_to_directed_lineseg (line 21) | def side_to_directed_lineseg(
  function wrap_angle (line 35) | def wrap_angle(

FILE: smart/utils/graph.py
  function add_edges (line 9) | def add_edges(
  function merge_edges (line 33) | def merge_edges(
  function complete_graph (line 45) | def complete_graph(
  function bipartite_dense_to_sparse (line 76) | def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
  function unbatch (line 85) | def unbatch(

FILE: smart/utils/list.py
  function safe_list_index (line 5) | def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:

FILE: smart/utils/log.py
  class Logging (line 6) | class Logging:
    method make_log_dir (line 8) | def make_log_dir(self, dirname='logs'):
    method get_log_filename (line 16) | def get_log_filename(self):
    method log (line 22) | def log(self, level='DEBUG', name="simagent"):
    method add_log (line 36) | def add_log(self, logger, level='DEBUG'):

FILE: smart/utils/nan_checker.py
  function check_nan_inf (line 3) | def check_nan_inf(t, s):

FILE: smart/utils/weight_init.py
  function weight_init (line 5) | def weight_init(m: nn.Module) -> None:
  },
  {
    "path": "train.py",
    "chars": 2539,
    "preview": "\nfrom argparse import ArgumentParser\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import LearningRate"
  },
  {
    "path": "val.py",
    "chars": 2039,
    "preview": "\nfrom argparse import ArgumentParser\nimport pytorch_lightning as pl\nfrom torch_geometric.loader import DataLoader\nfrom s"
  }
]

// ... and 2 more files (download for full content)
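As an example of working with the truncated previews above, the `safe_list_index` helper in `smart/utils/list.py` is cut off mid-body, but its signature and imports make the intent clear. A minimal completion sketch, assuming the obvious try/except body (the actual repository code may differ slightly):

```python
from typing import Any, List, Optional


def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    """Return the index of elem in ls, or None if elem is absent.

    list.index raises ValueError on a missing element; wrapping it
    turns that failure mode into an Optional return value.
    """
    try:
        return ls.index(elem)
    except ValueError:
        return None
```

For instance, `safe_list_index(['a', 'b', 'c'], 'b')` returns `1`, while `safe_list_index([], 'x')` returns `None` instead of raising.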

About this extraction

This page contains the full source code of the rainmaker22/SMART GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 52 files (205.4 KB), approximately 54.1k tokens, and a symbol index with 167 extracted functions, classes, methods, constants, and types. It can be used with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.
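The file-index entries above are plain JSON objects with `path`, `chars`, and `preview` fields, so they can be aggregated programmatically. A hypothetical sketch, where `index_json` stands in for the full 52-entry array (only a two-entry excerpt is shown here):

```python
import json

# Two-entry excerpt of the file index above; the real array has 52 entries.
index_json = '''[
  {"path": "smart/utils/list.py", "chars": 188, "preview": "..."},
  {"path": "smart/utils/log.py", "chars": 2043, "preview": "..."}
]'''

entries = json.loads(index_json)

# Total extracted size and the largest file in this excerpt.
total_chars = sum(e["chars"] for e in entries)
largest = max(entries, key=lambda e: e["chars"])["path"]
```

Run against the full array, the same two lines would reproduce the 205.4 KB total reported above and identify the largest source file.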

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
