Repository: rainmaker22/SMART
Branch: main
Commit: 42e658542b03
Files: 52
Total size: 205.4 KB
Directory structure:
gitextract_98xwsw9y/
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── configs/
│   ├── train/
│   │   └── train_scalable.yaml
│   └── validation/
│       └── validation_scalable.yaml
├── data_preprocess.py
├── environment.yml
├── pyproject.toml
├── requirements.txt
├── scripts/
│   ├── install_pyg.sh
│   └── traj_clstering.py
├── smart/
│   ├── __init__.py
│   ├── datamodules/
│   │   ├── __init__.py
│   │   └── scalable_datamodule.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── preprocess.py
│   │   └── scalable_dataset.py
│   ├── layers/
│   │   ├── __init__.py
│   │   ├── attention_layer.py
│   │   ├── fourier_embedding.py
│   │   └── mlp_layer.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── average_meter.py
│   │   ├── min_ade.py
│   │   ├── min_fde.py
│   │   ├── next_token_cls.py
│   │   └── utils.py
│   ├── model/
│   │   ├── __init__.py
│   │   └── smart.py
│   ├── modules/
│   │   ├── __init__.py
│   │   ├── agent_decoder.py
│   │   ├── map_decoder.py
│   │   └── smart_decoder.py
│   ├── preprocess/
│   │   ├── __init__.py
│   │   └── preprocess.py
│   ├── tokens/
│   │   ├── __init__.py
│   │   ├── cluster_frame_5_2048.pkl
│   │   └── map_traj_token5.pkl
│   ├── transforms/
│   │   ├── __init__.py
│   │   └── target_builder.py
│   └── utils/
│       ├── __init__.py
│       ├── cluster_reader.py
│       ├── config.py
│       ├── geometry.py
│       ├── graph.py
│       ├── list.py
│       ├── log.py
│       ├── nan_checker.py
│       └── weight_init.py
├── train.py
└── val.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.github
ckpt/
# assets/
# C extensions
*.so
# /assets
/data
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
*.jpg
env/
venv/
ENV/
env.bak/
venv.bak/
pyg_depend/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# IDEs
.idea
.vscode
# seed project
av2/
lightning_logs/
lightning_logs_/
lightning_l/
.DS_Store
data/argo
data/res
data/waymo*
fig*/
data/waymo_token
data/submission
data/token_seq_emb_nuplan
data/token_seq_emb_waymo
data/nuplan*
submission.tar.gz
data/feat*
data/scalable
data/pos_data
res_metrics*
gathered*
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<div align="center">
# SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction
[Paper](https://arxiv.org/abs/2405.15677) | [Webpage](https://smart-motion.github.io/smart/)
</div>
- **Champion** (ranked 1st) of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
## News
- **[December 31, 2024]** SMART-Planner achieved state-of-the-art performance on **nuPlan closed-loop planning**
- **[September 26, 2024]** SMART was **accepted to** NeurIPS 2024
- **[August 31, 2024]** Code released
- **[May 24, 2024]** SMART won the championship of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
- **[May 24, 2024]** SMART paper released on [arxiv](https://arxiv.org/abs/2405.15677)
## Introduction
This repository contains the official implementation of SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction. SMART is a motion-generation paradigm for autonomous driving that tokenizes vectorized map and agent-trajectory data into discrete sequence tokens and generates multi-agent motion via next-token prediction.
https://github.com/user-attachments/assets/74a61627-8444-4e54-bb10-d317dd2aacd9
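To make the token-based paradigm concrete, here is a minimal, hypothetical sketch of nearest-neighbor trajectory tokenization. This is *not* the repo's actual tokenizer (the real vocabulary lives in `smart/tokens/` and is built by `scripts/traj_clstering.py`); all names, shapes, and the toy vocabulary below are invented for illustration:

```python
import numpy as np

# Illustrative sketch only: discretize short trajectory segments by matching
# each one to its nearest entry in a fixed vocabulary of motion tokens.
# Shapes and the toy vocabulary are made up for demonstration.

def tokenize_trajectory(segments: np.ndarray, vocab: np.ndarray) -> np.ndarray:
    """Map each (T, 2) trajectory segment to the index of its nearest vocab token.

    segments: (N, T, 2) array of segments; vocab: (K, T, 2) token vocabulary.
    """
    diff = segments[:, None] - vocab[None]      # (N, K, T, 2) pairwise differences
    dist = np.square(diff).sum(axis=(-1, -2))   # (N, K) summed squared error
    return dist.argmin(axis=1)                  # (N,) nearest-token indices

rng = np.random.default_rng(0)
vocab = rng.normal(size=(16, 5, 2))             # toy vocabulary of 16 tokens
segments = vocab[[3, 7]] + 0.01 * rng.normal(size=(2, 5, 2))  # near tokens 3, 7
print(tokenize_trajectory(segments, vocab))     # -> [3 7]
```

In SMART the vocabulary is built by clustering real trajectory segments (the 2048-entry `cluster_frame_5_2048.pkl` suggests 2048 motion tokens over short frame windows), and the resulting indices are the discrete tokens the model learns to predict autoregressively.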
## Requirements
To set up the environment, you can use conda to create and activate a new environment with the necessary dependencies:
```bash
conda env create -f environment.yml
conda activate SMART
pip install -r requirements.txt
```
If you encounter issues while installing the PyG dependencies, run the following script:
```setup
bash scripts/install_pyg.sh
```
Alternatively, you can configure the environment in your preferred way. Installing the latest versions of PyTorch, PyG, and PyTorch Lightning should suffice.
## Data installation
**Step 1: Download the Dataset**
Download the Waymo Open Motion Dataset (`scenario protocol` format) and organize the data as follows:
```
SMART
├── data
│   ├── waymo
│   │   ├── scenario
│   │   │   ├── training
│   │   │   ├── validation
│   │   │   ├── testing
├── model
├── tools
```
**Step 2: Install the Waymo Open Dataset API**
Follow the instructions [here](https://github.com/waymo-research/waymo-open-dataset) to install the Waymo Open Dataset API.
**Step 3: Preprocess the Dataset**
Preprocess the dataset by running:
```bash
python data_preprocess.py --input_dir ./data/waymo/scenario/training --output_dir ./data/waymo_processed/training
```
The first path is the raw data path, and the second is the output data path.
The processed data will be saved to the `data/waymo_processed/` directory as follows:
```
SMART
├── data
│   ├── waymo_processed
│   │   ├── training
│   │   ├── validation
│   │   ├── testing
├── model
├── utils
```
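As a quick sanity check after preprocessing, you can count how many pickled scenarios were written. This helper is not part of the repo; it only mirrors the `.pkl`/`.pickle` suffix convention the loaders use:

```python
from pathlib import Path

# Hedged helper (not part of this repository): count the pickled scenarios
# that data_preprocess.py wrote, to sanity-check a preprocessing run.
def count_processed(processed_dir: str) -> int:
    root = Path(processed_dir)
    if not root.is_dir():
        return 0
    # The datasets in this repo look for files ending in pkl/pickle.
    return sum(1 for p in root.iterdir() if p.suffix in {".pkl", ".pickle"})
```

For example, `count_processed("data/waymo_processed/training")` should match the number of raw scenario shards you preprocessed.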
## Training
To train the model, run the following command:
```train
python train.py --config ${config_path}
```
The default config path is `configs/train/train_scalable.yaml`. Ensure you have downloaded and prepared the Waymo data for training.
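The config relies on a YAML anchor (`&time_info`) whose keys are merged into the `Dataset`, `Model`, and `decoder` sections via the `<<: *time_info` merge key. A minimal demonstration of how those merges expand, assuming PyYAML is installed (the snippet uses a trimmed inline copy of the config, not the real file):

```python
import yaml

# Demonstrates the anchor/merge-key pattern used in
# configs/train/train_scalable.yaml: keys defined once under &time_info
# are copied into any section that says `<<: *time_info`.
doc = """
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  token_size: 2048
Model:
  hidden_dim: 128
  <<: *time_info
"""
cfg = yaml.safe_load(doc)
print(cfg["Model"]["num_historical_steps"])  # -> 11
print(cfg["Model"]["token_size"])            # -> 2048
```

This is why `num_historical_steps`, `num_future_steps`, and `token_size` only need to be edited once at the top of the file.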
## Evaluation
To evaluate the model, run:
```eval
python val.py --config ${config_path} --pretrain_ckpt ${ckpt_path}
```
This will evaluate the model using the configuration and checkpoint provided.
## Pre-trained Models
To comply with the WOMD participation agreement, we will release the model parameters of a medium-sized model not trained on Waymo data. Users can fine-tune this model with Waymo data as needed.
## Results
### Waymo Open Motion Dataset Sim Agents Challenge
Our model achieves the following performance on the [Waymo Open Motion Dataset Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/):
| Model name | Metric Score |
| :-----------: | ------------ |
| SMART-tiny | 0.7591 |
| SMART-large | 0.7614 |
| SMART-zeroshot| 0.7210 |
### NuPlan Closed-loop Planning
**SMART-Planner** achieved state-of-the-art performance among learning-based algorithms on **nuPlan closed-loop planning**. The results on val14 are shown below:

## Citation
If you find this repository useful, please consider citing our work and giving us a star:
```citation
@article{wu2024smart,
  title={SMART: Scalable Multi-agent Real-time Simulation via Next-token Prediction},
  author={Wu, Wei and Feng, Xiaoxin and Gao, Ziyan and Kan, Yuheng},
  journal={arXiv preprint arXiv:2405.15677},
  year={2024}
}
```
## Acknowledgements
Special thanks to the [QCNet](https://github.com/ZikangZhou/QCNet) repository for providing valuable reference code that significantly influenced this work.
## License
All code in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
================================================
FILE: __init__.py
================================================
================================================
FILE: configs/train/train_scalable.yaml
================================================
# Config schema; this YAML supports validating cases sourced from different datasets.
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  use_intention: True
  token_size: 2048

Dataset:
  root:
  train_batch_size: 1
  val_batch_size: 1
  test_batch_size: 1
  shuffle: True
  num_workers: 1
  pin_memory: True
  persistent_workers: True
  train_raw_dir: ["data/valid_demo"]
  val_raw_dir: ["data/valid_demo"]
  test_raw_dir:
  transform: WaymoTargetBuilder
  train_processed_dir:
  val_processed_dir:
  test_processed_dir:
  dataset: "scalable"
  <<: *time_info

Trainer:
  strategy: ddp_find_unused_parameters_false
  accelerator: "gpu"
  devices: 1
  max_epochs: 32
  save_ckpt_path:
  num_nodes: 1
  mode:
  ckpt_path:
  precision: 32
  accumulate_grad_batches: 1

Model:
  mode: "train"
  predictor: "smart"
  dataset: "waymo"
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: False
  num_heads: 8
  <<: *time_info
  head_dim: 16
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32
  decoder:
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    time_span: 30
================================================
FILE: configs/validation/validation_scalable.yaml
================================================
# Config schema; this YAML supports validating cases sourced from different datasets.
time_info: &time_info
  num_historical_steps: 11
  num_future_steps: 80
  token_size: 2048

Dataset:
  root:
  batch_size: 1
  shuffle: True
  num_workers: 1
  pin_memory: True
  persistent_workers: True
  train_raw_dir:
  val_raw_dir: ["data/valid_demo"]
  test_raw_dir:
  TargetBuilder: WaymoTargetBuilder
  train_processed_dir:
  val_processed_dir:
  test_processed_dir:
  dataset: "scalable"
  <<: *time_info

Trainer:
  strategy: ddp_find_unused_parameters_false
  accelerator: "gpu"
  devices: 1
  max_epochs: 32
  save_ckpt_path:
  num_nodes: 1
  mode:
  ckpt_path:
  precision: 32
  accumulate_grad_batches: 1

Model:
  mode: "validation"
  predictor: "smart"
  dataset: "waymo"
  input_dim: 2
  hidden_dim: 128
  output_dim: 2
  output_head: False
  num_heads: 8
  <<: *time_info
  head_dim: 16
  dropout: 0.1
  num_freq_bands: 64
  lr: 0.0005
  warmup_steps: 0
  total_steps: 32
  decoder:
    <<: *time_info
    num_map_layers: 3
    num_agent_layers: 6
    a2a_radius: 60
    pl2pl_radius: 10
    pl2a_radius: 30
    time_span: 30
================================================
FILE: data_preprocess.py
================================================
import numpy as np
import pandas as pd
import os
import torch
import pickle
from tqdm import tqdm
from typing import Any, Dict, List, Optional
import easydict

predict_unseen_agents = False
vector_repr = True
root = ''
split = 'train'
raw_dir = os.path.join(root, split, 'raw')
_raw_dir = raw_dir
if os.path.isdir(_raw_dir):
    _raw_file_names = [name for name in os.listdir(_raw_dir)]
else:
    _raw_file_names = []
processed_dir = os.path.join(root, split, 'processed')
_processed_dir = processed_dir
if os.path.isdir(_processed_dir):
    _processed_file_names = [name for name in os.listdir(_processed_dir) if
                             name.endswith(('pkl', 'pickle'))]
else:
    _processed_file_names = []

_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
_point_sides = ['LEFT', 'RIGHT', 'CENTER']
_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
_polygon_is_intersections = [True, False, None]

Lane_type_hash = {
    4: "BIKE",
    3: "VEHICLE",
    2: "VEHICLE",
    1: "BUS"
}
boundary_type_hash = {
    5: "UNKNOWN",
    6: "DASHED_WHITE",
    7: "SOLID_WHITE",
    8: "DOUBLE_DASH_WHITE",
    9: "DASHED_YELLOW",
    10: "DOUBLE_DASH_YELLOW",
    11: "SOLID_YELLOW",
    12: "DOUBLE_SOLID_YELLOW",
    13: "DASH_SOLID_YELLOW",
    14: "UNKNOWN",
    15: "EDGE",
    16: "EDGE"
}


def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    try:
        return ls.index(elem)
    except ValueError:
        return None
def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
    if not predict_unseen_agents:  # filter out agents that are unseen during the historical time steps
        historical_df = df[df['timestep'] == num_historical_steps - 1]
        agent_ids = list(historical_df['track_id'].unique())
        df = df[df['track_id'].isin(agent_ids)]
    else:
        agent_ids = list(df['track_id'].unique())

    num_agents = len(agent_ids)

    # initialization
    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    agent_id: List[Optional[str]] = [None] * num_agents
    agent_type = torch.zeros(num_agents, dtype=torch.uint8)
    agent_category = torch.zeros(num_agents, dtype=torch.uint8)
    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)

    for track_id, track_df in df.groupby('track_id'):
        agent_idx = agent_ids.index(track_id)
        agent_steps = track_df['timestep'].values

        valid_mask[agent_idx, agent_steps] = True
        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
        predict_mask[agent_idx, agent_steps] = True
        if vector_repr:  # a time step t is valid only when both t and t-1 are valid
            valid_mask[agent_idx, 1: num_historical_steps] = (
                valid_mask[agent_idx, :num_historical_steps - 1] &
                valid_mask[agent_idx, 1: num_historical_steps])
            valid_mask[agent_idx, 0] = False
        predict_mask[agent_idx, :num_historical_steps] = False
        if not current_valid_mask[agent_idx]:
            predict_mask[agent_idx, num_historical_steps:] = False

        agent_id[agent_idx] = track_id
        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
        agent_category[agent_idx] = track_df['object_category'].values[0]
        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
                                                                          track_df['position_y'].values,
                                                                          track_df['position_z'].values],
                                                                         axis=-1)).float()
        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
                                                                          track_df['velocity_y'].values],
                                                                         axis=-1)).float()
        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
                                                                       track_df['width'].values,
                                                                       track_df["height"].values],
                                                                      axis=-1)).float()

    av_idx = agent_id.index(av_id)
    if split == 'test':
        predict_mask[current_valid_mask
                     | (agent_category == 2)
                     | (agent_category == 3), num_historical_steps:] = True

    return {
        'num_nodes': num_agents,
        'av_index': av_idx,
        'valid_mask': valid_mask,
        'predict_mask': predict_mask,
        'id': agent_id,
        'type': agent_type,
        'category': agent_category,
        'position': position,
        'heading': heading,
        'velocity': velocity,
        'shape': shape
    }
def get_map_features(map_infos, tf_current_light, dim=3):
    lane_segments = map_infos['lane']
    all_polylines = map_infos["all_polylines"]
    crosswalks = map_infos['crosswalk']
    road_edges = map_infos['road_edge']
    road_lines = map_infos['road_line']
    lane_segment_ids = [info["id"] for info in lane_segments]
    cross_walk_ids = [info["id"] for info in crosswalks]
    road_edge_ids = [info["id"] for info in road_edges]
    road_line_ids = [info["id"] for info in road_lines]
    polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids
    num_polygons = len(lane_segment_ids) + len(road_edge_ids) + len(road_line_ids) + len(cross_walk_ids)

    # initialization
    polygon_type = torch.zeros(num_polygons, dtype=torch.uint8)
    polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3
    point_position: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_height: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_type: List[Optional[torch.Tensor]] = [None] * num_polygons

    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type])
        res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)]
        if len(res) != 0:
            polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item())
        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index('CENTERLINE')
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    for lane_segment in road_edges:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")
        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index('EDGE')
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    for lane_segment in road_lines:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polyline_index = lane_segment.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index("VEHICLE")
        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index(boundary_type_hash[lane_segment.type])
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    for crosswalk in crosswalks:
        crosswalk = easydict.EasyDict(crosswalk)
        lane_segment_idx = polygon_ids.index(crosswalk.id)
        polyline_index = crosswalk.polyline_index
        centerline = all_polylines[polyline_index[0]:polyline_index[1], :]
        centerline = torch.from_numpy(centerline).float()
        polygon_type[lane_segment_idx] = _polygon_types.index("PEDESTRIAN")
        point_position[lane_segment_idx] = torch.cat([centerline[:-1, :dim]], dim=0)
        center_vectors = centerline[1:] - centerline[:-1]
        point_orientation[lane_segment_idx] = torch.cat([torch.atan2(center_vectors[:, 1], center_vectors[:, 0])], dim=0)
        point_magnitude[lane_segment_idx] = torch.norm(torch.cat([center_vectors[:, :2]], dim=0), p=2, dim=-1)
        point_height[lane_segment_idx] = torch.cat([center_vectors[:, 2]], dim=0)
        center_type = _point_types.index("CROSSWALK")
        point_type[lane_segment_idx] = torch.cat(
            [torch.full((len(center_vectors),), center_type, dtype=torch.uint8)], dim=0)

    num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long)
    point_to_polygon_edge_index = torch.stack(
        [torch.arange(num_points.sum(), dtype=torch.long),
         torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0)

    polygon_to_polygon_edge_index = []
    polygon_to_polygon_type = []
    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        pred_inds = []
        for pred in lane_segment.entry_lanes:
            pred_idx = safe_list_index(polygon_ids, pred)
            if pred_idx is not None:
                pred_inds.append(pred_idx)
        if len(pred_inds) != 0:
            polygon_to_polygon_edge_index.append(
                torch.stack([torch.tensor(pred_inds, dtype=torch.long),
                             torch.full((len(pred_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
            polygon_to_polygon_type.append(
                torch.full((len(pred_inds),), _polygon_to_polygon_types.index('PRED'), dtype=torch.uint8))
        succ_inds = []
        for succ in lane_segment.exit_lanes:
            succ_idx = safe_list_index(polygon_ids, succ)
            if succ_idx is not None:
                succ_inds.append(succ_idx)
        if len(succ_inds) != 0:
            polygon_to_polygon_edge_index.append(
                torch.stack([torch.tensor(succ_inds, dtype=torch.long),
                             torch.full((len(succ_inds),), lane_segment_idx, dtype=torch.long)], dim=0))
            polygon_to_polygon_type.append(
                torch.full((len(succ_inds),), _polygon_to_polygon_types.index('SUCC'), dtype=torch.uint8))
        if len(lane_segment.left_neighbors) != 0:
            left_neighbor_ids = lane_segment.left_neighbors
            for left_neighbor_id in left_neighbor_ids:
                left_idx = safe_list_index(polygon_ids, left_neighbor_id)
                if left_idx is not None:
                    polygon_to_polygon_edge_index.append(
                        torch.tensor([[left_idx], [lane_segment_idx]], dtype=torch.long))
                    polygon_to_polygon_type.append(
                        torch.tensor([_polygon_to_polygon_types.index('LEFT')], dtype=torch.uint8))
        if len(lane_segment.right_neighbors) != 0:
            right_neighbor_ids = lane_segment.right_neighbors
            for right_neighbor_id in right_neighbor_ids:
                right_idx = safe_list_index(polygon_ids, right_neighbor_id)
                if right_idx is not None:
                    polygon_to_polygon_edge_index.append(
                        torch.tensor([[right_idx], [lane_segment_idx]], dtype=torch.long))
                    polygon_to_polygon_type.append(
                        torch.tensor([_polygon_to_polygon_types.index('RIGHT')], dtype=torch.uint8))

    if len(polygon_to_polygon_edge_index) != 0:
        polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1)
        polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0)
    else:
        polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long)
        polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8)

    map_data = {
        'map_polygon': {},
        'map_point': {},
        ('map_point', 'to', 'map_polygon'): {},
        ('map_polygon', 'to', 'map_polygon'): {},
    }
    map_data['map_polygon']['num_nodes'] = num_polygons
    map_data['map_polygon']['type'] = polygon_type
    map_data['map_polygon']['light_type'] = polygon_light_type
    if len(num_points) == 0:
        map_data['map_point']['num_nodes'] = 0
map_data['map_point']['position'] = torch.tensor([], dtype=torch.float)
map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float)
map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float)
if dim == 3:
map_data['map_point']['height'] = torch.tensor([], dtype=torch.float)
map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8)
map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8)
else:
map_data['map_point']['num_nodes'] = num_points.sum().item()
map_data['map_point']['position'] = torch.cat(point_position, dim=0)
map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0)
map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0)
if dim == 3:
map_data['map_point']['height'] = torch.cat(point_height, dim=0)
map_data['map_point']['type'] = torch.cat(point_type, dim=0)
map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index
map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index
map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type
# import matplotlib.pyplot as plt
# plt.axis('equal')
# plt.scatter(map_data['map_point']['position'][:, 0],
# map_data['map_point']['position'][:, 1], s=0.2, c='black', edgecolors='none')
# plt.show(dpi=600)
return map_data
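The point-to-polygon edge index above pairs every map point with its parent polygon by enumerating all points in row 0 and repeating each polygon id by its point count in row 1 (via `repeat_interleave`). A minimal numpy sketch of the same construction, using `np.repeat` as the numpy analogue and hypothetical point counts:

```python
import numpy as np

# Row 0 enumerates all points; row 1 repeats each polygon id by its point count.
num_points = np.array([3, 2])          # points per polygon (hypothetical counts)
num_polygons = len(num_points)

edge_index = np.stack([
    np.arange(num_points.sum()),                     # point indices 0..4
    np.repeat(np.arange(num_polygons), num_points),  # owning polygon per point
])
print(edge_index)
# [[0 1 2 3 4]
#  [0 0 0 1 1]]
```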
def process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, start_timestamp, end_timestamp):
agents_array = track_info["trajs"].transpose(1, 0, 2)  # (num_timestamps, num_agents, 10)
object_id = np.array(track_info["object_id"])
object_type = track_info["object_type"]
id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))}
def type_hash(x):
tp = id_hash[x]
type_re_hash = {
"TYPE_VEHICLE": "vehicle",
"TYPE_PEDESTRIAN": "pedestrian",
"TYPE_CYCLIST": "cyclist",
"TYPE_OTHER": "background",
"TYPE_UNSET": "background"
}
return type_re_hash[tp]
columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',
'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y',
'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps',
'focal_track_id', 'city']
new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11))
new_columns[:11, :, 0] = True  # the first 11 timesteps (history) are marked observed
new_columns[11:, :, 0] = False  # the remaining timesteps are future
for index in range(new_columns.shape[0]):
new_columns[index, :, 4] = int(index)  # timestep index
new_columns[..., 1] = object_id  # track_id
new_columns[..., 2] = object_id  # object_type (remapped from the id via type_hash below)
new_columns[:, tracks_to_predict["track_index"], 3] = 3  # object_category 3 marks focal tracks
new_columns[..., 5] = 11  # scenario_id placeholder, overwritten below
new_columns[..., 6] = int(start_timestamp)
new_columns[..., 7] = int(end_timestamp)
new_columns[..., 8] = 91  # Waymo scenarios span 91 timestamps
new_columns[..., 9] = object_id  # focal_track_id
new_columns[..., 10] = 10086  # dummy city code
new_agents_array = np.concatenate([new_columns, agents_array], axis=-1)
new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1])
new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10]]
new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns)
new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash)
new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int)
new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int)
new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int)
new_agents_array["scenario_id"] = scenario_id
return new_agents_array
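The validity filter in `process_agent` keeps only rows whose last feature (the `valid` flag) equals 1.0, collapsing the `(num_steps, num_agents, F)` array into a flat `(num_valid, F)` table. A small self-contained demo of that boolean-mask-and-reshape trick, with toy values:

```python
import numpy as np

arr = np.zeros((2, 2, 3))
arr[..., -1] = [[1.0, 0.0], [1.0, 1.0]]   # valid flags (hypothetical)
arr[..., 0] = [[10, 11], [20, 21]]        # some feature

# Boolean indexing drops invalid rows; reshape restores the 2-D table layout.
flat = arr[arr[..., -1] == 1.0].reshape(-1, arr.shape[-1])
print(flat[:, 0])   # features of the 3 valid rows
```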
def process_dynamic_map(dynamic_map_infos):
lane_ids = dynamic_map_infos["lane_id"]
tf_lights = []
for t in range(len(lane_ids)):
lane_id = lane_ids[t]
time = np.ones_like(lane_id) * t
state = dynamic_map_infos["state"][t]
tf_light = np.concatenate([lane_id, time, state], axis=0)
tf_lights.append(tf_light)
tf_lights = np.concatenate(tf_lights, axis=1).transpose(1, 0)
tf_lights = pd.DataFrame(data=tf_lights, columns=["lane_id", "time_step", "state"])
tf_lights["time_step"] = tf_lights["time_step"].astype("str")
tf_lights["lane_id"] = tf_lights["lane_id"].astype("str")
tf_lights["state"] = tf_lights["state"].astype("str")
tf_lights.loc[tf_lights["state"].str.contains("STOP"), ["state"] ] = 'LANE_STATE_STOP'
tf_lights.loc[tf_lights["state"].str.contains("GO"), ["state"] ] = 'LANE_STATE_GO'
tf_lights.loc[tf_lights["state"].str.contains("CAUTION"), ["state"] ] = 'LANE_STATE_CAUTION'
return tf_lights
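`process_dynamic_map` stacks `[lane_id, time, state]` per timestep into a `(3, n)` block, concatenates the blocks along axis 1, and transposes to get one `(total_signals, 3)` table. A self-contained shape check with toy string arrays (two timesteps, hypothetical lane ids and states):

```python
import numpy as np

lane_ids = [np.array([["7", "8"]]), np.array([["7"]])]        # (1, n_t) per timestep
states = [np.array([["GO", "STOP"]]), np.array([["GO"]])]

rows = []
for t, (lane_id, state) in enumerate(zip(lane_ids, states)):
    time = np.full(lane_id.shape, str(t))
    rows.append(np.concatenate([lane_id, time, state], axis=0))  # (3, n_t)
table = np.concatenate(rows, axis=1).transpose(1, 0)             # (total, 3)
print(table.shape)   # (3, 3): three observed signals, three columns
```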
polyline_type = {
# for lane
'TYPE_UNDEFINED': -1,
'TYPE_FREEWAY': 1,
'TYPE_SURFACE_STREET': 2,
'TYPE_BIKE_LANE': 3,
# for roadline
'TYPE_UNKNOWN': -1,
'TYPE_BROKEN_SINGLE_WHITE': 6,
'TYPE_SOLID_SINGLE_WHITE': 7,
'TYPE_SOLID_DOUBLE_WHITE': 8,
'TYPE_BROKEN_SINGLE_YELLOW': 9,
'TYPE_BROKEN_DOUBLE_YELLOW': 10,
'TYPE_SOLID_SINGLE_YELLOW': 11,
'TYPE_SOLID_DOUBLE_YELLOW': 12,
'TYPE_PASSING_DOUBLE_YELLOW': 13,
# for roadedge
'TYPE_ROAD_EDGE_BOUNDARY': 15,
'TYPE_ROAD_EDGE_MEDIAN': 16,
# for stopsign
'TYPE_STOP_SIGN': 17,
# for crosswalk
'TYPE_CROSSWALK': 18,
# for speed bump
'TYPE_SPEED_BUMP': 19
}
object_type = {
0: 'TYPE_UNSET',
1: 'TYPE_VEHICLE',
2: 'TYPE_PEDESTRIAN',
3: 'TYPE_CYCLIST',
4: 'TYPE_OTHER'
}
signal_state = {
0: 'LANE_STATE_UNKNOWN',
# // States for traffic signals with arrows.
1: 'LANE_STATE_ARROW_STOP',
2: 'LANE_STATE_ARROW_CAUTION',
3: 'LANE_STATE_ARROW_GO',
# // Standard round traffic signals.
4: 'LANE_STATE_STOP',
5: 'LANE_STATE_CAUTION',
6: 'LANE_STATE_GO',
# // Flashing light signals.
7: 'LANE_STATE_FLASHING_STOP',
8: 'LANE_STATE_FLASHING_CAUTION'
}
signal_state_to_id = {val: key for key, val in signal_state.items()}
def decode_tracks_from_proto(tracks):
track_infos = {
'object_id': [],
'object_type': [], # e.g. 'TYPE_VEHICLE'; values come from the object_type mapping above
'trajs': []
}
for cur_data in tracks: # number of objects
cur_traj = [np.array([x.center_x, x.center_y, x.center_z, x.length, x.width, x.height, x.heading,
x.velocity_x, x.velocity_y, x.valid], dtype=np.float32) for x in cur_data.states]
cur_traj = np.stack(cur_traj, axis=0) # (num_timestamp, 10)
track_infos['object_id'].append(cur_data.id)
track_infos['object_type'].append(object_type[cur_data.object_type])
track_infos['trajs'].append(cur_traj)
track_infos['trajs'] = np.stack(track_infos['trajs'], axis=0) # (num_objects, num_timestamp, 10)
return track_infos
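Each proto state contributes 10 scalar fields per timestep, so the stacked per-object trajectory is `(num_timestamps, 10)`. The field access can be checked without the Waymo protos by mimicking a state with `SimpleNamespace` (mock data, not a real proto):

```python
import numpy as np
from types import SimpleNamespace

def make_state(v):
    # Stand-in for a Waymo ObjectState proto with the 10 fields read above.
    return SimpleNamespace(center_x=v, center_y=v, center_z=0.0, length=4.0,
                           width=2.0, height=1.5, heading=0.0,
                           velocity_x=0.0, velocity_y=0.0, valid=1)

track = SimpleNamespace(id=42, object_type=1,
                        states=[make_state(float(t)) for t in range(3)])
traj = np.stack([np.array([s.center_x, s.center_y, s.center_z, s.length, s.width,
                           s.height, s.heading, s.velocity_x, s.velocity_y,
                           s.valid], dtype=np.float32) for s in track.states])
print(traj.shape)   # (3, 10): (num_timestamps, 10 features)
```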
from collections import defaultdict
def decode_map_features_from_proto(map_features):
map_infos = {
'lane': [],
'road_line': [],
'road_edge': [],
'stop_sign': [],
'crosswalk': [],
'speed_bump': [],
'lane_dict': {},
'lane2other_dict': {}
}
polylines = []
point_cnt = 0
lane2other_dict = defaultdict(list)
for cur_data in map_features:
cur_info = {'id': cur_data.id}
if cur_data.lane.ByteSize() > 0:
cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph
cur_info['type'] = cur_data.lane.type + 1 # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane
cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]
cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]
cur_info['interpolating'] = cur_data.lane.interpolating
cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)
cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)
cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]
cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]
cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]
lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])
global_type = cur_info['type']
cur_polyline = np.stack(
[np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],
axis=0)
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
if cur_polyline.shape[0] <= 1:
continue
map_infos['lane'].append(cur_info)
map_infos['lane_dict'][cur_data.id] = cur_info
elif cur_data.road_line.ByteSize() > 0:
cur_info['type'] = cur_data.road_line.type + 5
global_type = cur_info['type']
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
cur_data.road_line.polyline], axis=0)
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
if cur_polyline.shape[0] <= 1:
continue
map_infos['road_line'].append(cur_info)
elif cur_data.road_edge.ByteSize() > 0:
cur_info['type'] = cur_data.road_edge.type + 14
global_type = cur_info['type']
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
cur_data.road_edge.polyline], axis=0)
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
if cur_polyline.shape[0] <= 1:
continue
map_infos['road_edge'].append(cur_info)
elif cur_data.stop_sign.ByteSize() > 0:
cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
for i in cur_info['lane_ids']:
lane2other_dict[i].append(cur_data.id)
point = cur_data.stop_sign.position
cur_info['position'] = np.array([point.x, point.y, point.z])
global_type = polyline_type['TYPE_STOP_SIGN']
cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
# a stop sign polyline is a single point, so the <= 1 length check used for the
# other element types would always drop it; append unconditionally instead
map_infos['stop_sign'].append(cur_info)
elif cur_data.crosswalk.ByteSize() > 0:
global_type = polyline_type['TYPE_CROSSWALK']
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
cur_data.crosswalk.polygon], axis=0)
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
if cur_polyline.shape[0] <= 1:
continue
map_infos['crosswalk'].append(cur_info)
elif cur_data.speed_bump.ByteSize() > 0:
global_type = polyline_type['TYPE_SPEED_BUMP']
cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
cur_data.speed_bump.polygon], axis=0)
cur_polyline = np.concatenate((cur_polyline[:, 0:3], cur_polyline[:, 3:]), axis=-1)
if cur_polyline.shape[0] <= 1:
continue
map_infos['speed_bump'].append(cur_info)
else:
# print(cur_data)
continue
polylines.append(cur_polyline)
cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))
point_cnt += len(cur_polyline)
if len(polylines) > 0:
polylines = np.concatenate(polylines, axis=0).astype(np.float32)
else:
polylines = np.zeros((0, 5), dtype=np.float32)  # width matches the 5 per-point features
map_infos['all_polylines'] = polylines
map_infos['lane2other_dict'] = lane2other_dict
return map_infos
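Every map element stores `'polyline_index' = (point_cnt, point_cnt + len)` into the single concatenated `all_polylines` array, so slicing with that tuple recovers exactly its own points. A toy illustration of the bookkeeping:

```python
import numpy as np

pl_a = np.zeros((3, 5))                    # 3-point polyline, 5 features per point
pl_b = np.ones((2, 5))                     # 2-point polyline
all_polylines = np.concatenate([pl_a, pl_b], axis=0)

index_a, index_b = (0, 3), (3, 5)          # (point_cnt, point_cnt + len) per element
recovered_b = all_polylines[index_b[0]:index_b[1]]
print(recovered_b.shape)   # (2, 5)
```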
def decode_dynamic_map_states_from_proto(dynamic_map_states):
dynamic_map_infos = {
'lane_id': [],
'state': [],
'stop_point': []
}
for cur_data in dynamic_map_states: # (num_timestamp)
lane_id, state, stop_point = [], [], []
for cur_signal in cur_data.lane_states: # (num_observed_signals)
lane_id.append(cur_signal.lane)
state.append(signal_state[cur_signal.state])
stop_point.append([cur_signal.stop_point.x, cur_signal.stop_point.y, cur_signal.stop_point.z])
dynamic_map_infos['lane_id'].append(np.array([lane_id]))
dynamic_map_infos['state'].append(np.array([state]))
dynamic_map_infos['stop_point'].append(np.array([stop_point]))
return dynamic_map_infos
def process_single_data(scenario):
info = {}
info['scenario_id'] = scenario.scenario_id
info['timestamps_seconds'] = list(scenario.timestamps_seconds) # list of int of shape (91)
info['current_time_index'] = scenario.current_time_index # int, 10
info['sdc_track_index'] = scenario.sdc_track_index # int
info['objects_of_interest'] = list(scenario.objects_of_interest) # list, could be empty list
info['tracks_to_predict'] = {
'track_index': [cur_pred.track_index for cur_pred in scenario.tracks_to_predict],
'difficulty': [cur_pred.difficulty for cur_pred in scenario.tracks_to_predict]
} # for training: suggestion of objects to train on, for val/test: need to be predicted
track_infos = decode_tracks_from_proto(scenario.tracks)
info['tracks_to_predict']['object_type'] = [track_infos['object_type'][cur_idx] for cur_idx in
info['tracks_to_predict']['track_index']]
# decode map related data
map_infos = decode_map_features_from_proto(scenario.map_features)
dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
save_infos = {
'track_infos': track_infos,
'dynamic_map_infos': dynamic_map_infos,
'map_infos': map_infos
}
save_infos.update(info)
return save_infos
import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2
def wm2argo(file, dir_name, output_dir):
file_path = os.path.join(dir_name, file)
dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)
for cnt, data in enumerate(dataset):
print(cnt)
scenario = scenario_pb2.Scenario()
scenario.ParseFromString(bytearray(data.numpy()))
save_infos = process_single_data(scenario) # pkl2mtr
map_info = save_infos["map_infos"]
track_info = save_infos['track_infos']
scenario_id = save_infos['scenario_id']
tracks_to_predict = save_infos['tracks_to_predict']
sdc_track_index = save_infos['sdc_track_index']
av_id = track_info["object_id"][sdc_track_index]
if len(tracks_to_predict["track_index"]) < 1:
continue  # skip this scenario but keep processing the rest of the TFRecord
dynamic_map_infos = save_infos["dynamic_map_infos"]
tf_lights = process_dynamic_map(dynamic_map_infos)
tf_current_light = tf_lights.loc[tf_lights["time_step"] == "11"]
map_data = get_map_features(map_info, tf_current_light)
new_agents_array = process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, 0, 91) # mtr2argo
data = dict()
data['scenario_id'] = new_agents_array['scenario_id'].values[0]
data['city'] = new_agents_array['city'].values[0]
data['agent'] = get_agent_features(new_agents_array, av_id, num_historical_steps=11)
data.update(map_data)
with open(os.path.join(output_dir, scenario_id + '.pkl'), "wb+") as f:
pickle.dump(data, f)
def batch_process9s_transformer(dir_name, output_dir, num_workers=2):
from functools import partial
import multiprocessing
packages = os.listdir(dir_name)
func = partial(
wm2argo, output_dir=output_dir, dir_name=dir_name)
with multiprocessing.Pool(num_workers) as p:
list(tqdm(p.imap(func, packages), total=len(packages)))
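`batch_process9s_transformer` binds the fixed directory arguments with `functools.partial` so only the filename varies across workers. The binding pattern is easy to see with plain `map` (`Pool.imap` applies the same callable, just in parallel); `wm2argo_stub` below is a hypothetical stand-in for `wm2argo`:

```python
from functools import partial

def wm2argo_stub(file, dir_name, output_dir):
    # Stand-in for wm2argo: just report which paths it would touch.
    return (dir_name + "/" + file, output_dir)

# Bind the fixed keyword arguments; only `file` remains free.
func = partial(wm2argo_stub, dir_name="in", output_dir="out")
results = list(map(func, ["a.tfrecord", "b.tfrecord"]))
print(results)   # [('in/a.tfrecord', 'out'), ('in/b.tfrecord', 'out')]
```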
from argparse import ArgumentParser
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument('--input_dir', type=str, default='data/waymo/scenario/training')
parser.add_argument('--output_dir', type=str, default='data/waymo_processed/training')
args = parser.parse_args()
files = os.listdir(args.input_dir)
for file in tqdm(files):
wm2argo(file, args.input_dir, args.output_dir)
# batch_process9s_transformer(args.input_dir, args.output_dir, num_workers="ur_cpu_count")
================================================
FILE: environment.yml
================================================
name: smart
channels:
- pytorch
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli-python=1.0.9=py39h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.9.24=h06a4308_0
- certifi=2024.8.30=py39h06a4308_0
- charset-normalizer=3.3.2=pyhd3eb1b0_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- ffmpeg=4.3=hf484d3e_0
- freetype=2.12.1=h4a9f257_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- idna=3.7=py39h06a4308_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jpeg=9e=h5eee18b_3
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.40=h12ee557_0
- lerc=3.0=h295c915_0
- libdeflate=1.17=h5eee18b_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.14=0
- libidn2=2.3.4=h5eee18b_0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libwebp-base=1.3.2=h5eee18b_1
- lz4-c=1.9.4=h6a678d5_1
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py39h5eee18b_1
- mkl_fft=1.3.10=py39h5eee18b_0
- mkl_random=1.2.7=py39h1128e8f_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=he7f1fd0_0
- openssl=3.0.15=h5eee18b_0
- pillow=10.4.0=py39h5eee18b_0
- pip=24.2=py39h06a4308_0
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.19=h955ad1f_1
- pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
- pytorch-mutex=1.0=cuda
- readline=8.2=h5eee18b_0
- requests=2.32.3=py39h06a4308_0
- setuptools=75.1.0=py39h06a4308_0
- sqlite=3.45.3=h5eee18b_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.14=h39e8969_0
- torchvision=0.13.1=py39_cu113
- typing_extensions=4.11.0=py39h06a4308_0
- urllib3=2.2.3=py39h06a4308_0
- wheel=0.44.0=py39h06a4308_0
- xz=5.4.6=h5eee18b_1
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.6=hc292b87_0
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "smart"
version = "0.0.0"
description = "Scalable Multi-agent Real-time Motion Generation via Next-token Prediction"
readme = "README.md"
authors = [
{name = "Xiaoxin Feng"},
{name = "Ziyan Gao"},
{name = "Yuheng Kan"}
]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
requires-python = ">=3.9"
dependencies = [
"easydict",
"numpy",
"pandas",
"pytorch-lightning",
"scipy",
"torch-cluster",
"torch-geometric",
"torch-scatter",
"torch",
"torchmetrics",
"tqdm",
]
[project.urls]
"Homepage" = "https://smart-motion.github.io/smart/"
"Repository" = "https://github.com/rainmaker22/SMART"
"Paper" = "https://arxiv.org/abs/2405.15677"
[tool.setuptools]
packages = ["smart"]
================================================
FILE: requirements.txt
================================================
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
async-timeout==4.0.3
attrs==24.2.0
contourpy==1.3.0
cycler==0.12.1
easydict==1.13
fonttools==4.54.1
frozenlist==1.4.1
fsspec==2024.10.0
importlib-resources==6.4.5
jinja2==3.1.4
kiwisolver==1.4.7
lightning-utilities==0.11.8
markupsafe==3.0.2
matplotlib==3.9.2
multidict==6.1.0
numpy==1.26.4
packaging==24.1
pandas==2.0.3
propcache==0.2.0
psutil==6.1.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
pytorch-lightning==2.0.3
pytz==2024.2
pyyaml==6.0.1
scipy==1.10.1
shapely==2.0.6
six==1.16.0
torch-cluster==1.6.0+pt112cu113
torch-geometric==2.6.1
torch-scatter==2.1.0+pt112cu113
torch-sparse==0.6.16+pt112cu113
torch-spline-conv==1.2.1+pt112cu113
torchmetrics==1.5.0
tqdm==4.66.5
tzdata==2024.2
yarl==1.16.0
zipp==3.20.2
waymo-open-dataset-tf-2-12-0==1.6.4
================================================
FILE: scripts/install_pyg.sh
================================================
#!/bin/bash
set -e  # stop on the first failed download or install
mkdir pyg_depend && cd pyg_depend
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_geometric
================================================
FILE: scripts/traj_clstering.py
================================================
from smart.utils.geometry import wrap_angle
import numpy as np
def average_distance_vectorized(point_set1, centroids):
dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :])**2, axis=-1))
return np.mean(dists, axis=2)
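`average_distance_vectorized` broadcasts `(M, 1, P, 2)` against `(1, K, P, 2)`: for every (contour, centroid) pair it computes corner-to-corner distances and averages them over the P contour points. A self-contained shape and value check with constant toy contours:

```python
import numpy as np

point_set1 = np.zeros((5, 4, 2))   # 5 contours, 4 corners each, at the origin
centroids = np.ones((3, 4, 2))     # 3 cluster centers, all corners at (1, 1)

# Broadcast to (5, 3, 4), then average over the 4 corners -> (5, 3).
dists = np.sqrt(np.sum((point_set1[:, None] - centroids[None])**2, axis=-1))
avg = dists.mean(axis=2)
print(avg.shape)   # (5, 3); every entry is sqrt(2)
```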
def assign_clusters(sub_X, centroids):
distances = average_distance_vectorized(sub_X, centroids)
return np.argmin(distances, axis=1)
def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
S = []
ret_traj_list = []
while len(S) < N:
num_all = X.shape[0]
if num_all == 0:
break  # all samples consumed before reaching N clusters
# randomly pick a candidate cluster center
choice_index = np.random.choice(num_all)
x0 = X[choice_index]
if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10:
continue
mean_sq_dist = np.sum((X - x0)**2, axis=(1, 2)) / 4  # mean squared distance over the 4 corners
res_mask = mean_sq_dist > (tol**2)
del_mask = ~res_mask
if cal_mean_heading:  # module-level flag set in the __main__ block below
del_contour = X[del_mask]
diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :]
del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean()
x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length)
del_traj = a_pos[del_mask]
ret_traj = del_traj.mean(0)[None, ...]
if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0:
# debug check: flag averaged trajectories with a large backward jump
print(ret_traj)
else:
x0 = x0[None, ...]
ret_traj = a_pos[choice_index][None, ...]
X = X[res_mask]
a_pos = a_pos[res_mask]
S.append(x0)
ret_traj_list.append(ret_traj)
centroids = np.concatenate(S, axis=0)
ret_traj = np.concatenate(ret_traj_list, axis=0)
# closest_dist_sq = np.sum((X - centroids[0])**2, axis=(1, 2))
# for k in range(1, K):
# new_dist_sq = np.sum((X - centroids[k - 1])**2, axis=(1, 2))
# closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq)
# probabilities = closest_dist_sq / np.sum(closest_dist_sq)
# centroids[k] = X[np.random.choice(N, p=probabilities)]
return centroids, ret_traj
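The core of `Kdisk_cluster` is greedy disk covering: pick a random sample as a center, delete every sample within `tol` of it, repeat until the vocabulary is full. A 1-D toy version of the masking loop (three well-separated groups, so the cluster count is deterministic regardless of the random picks):

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.array([0.0, 0.01, 0.02, 1.0, 1.01, 2.0])   # three groups, gaps >> tol
tol = 0.1
centers = []
while X.size > 0:
    x0 = X[rng.integers(X.size)]       # random candidate center
    keep = np.abs(X - x0) > tol        # analogous to res_mask above
    X = X[keep]                        # delete the covered samples
    centers.append(x0)
print(len(centers))   # 3 clusters: around 0, 1 and 2
```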
def cal_polygon_contour(x, y, theta, width, length):
left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
left_front = np.column_stack((left_front_x, left_front_y))
right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
right_front = np.column_stack((right_front_x, right_front_y))
right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
right_back = np.column_stack((right_back_x, right_back_y))
left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
left_back = np.column_stack((left_back_x, left_back_y))
polygon_contour = np.concatenate((left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)
return polygon_contour
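A quick sanity check of the corner math above, restated inline so the snippet is self-contained: an axis-aligned box (theta = 0) of length 4 along x and width 2 along y, centered at the origin, should have corners at (±2, ±1).

```python
import numpy as np

x, y, theta, width, length = 0.0, 0.0, 0.0, 2.0, 4.0
c, s = np.cos(theta), np.sin(theta)
# Same formulas as cal_polygon_contour, for a single pose.
left_front = (x + 0.5*length*c - 0.5*width*s, y + 0.5*length*s + 0.5*width*c)
right_front = (x + 0.5*length*c + 0.5*width*s, y + 0.5*length*s - 0.5*width*c)
right_back = (x - 0.5*length*c + 0.5*width*s, y - 0.5*length*s - 0.5*width*c)
left_back = (x - 0.5*length*c - 0.5*width*s, y - 0.5*length*s + 0.5*width*c)
print(left_front, right_front, right_back, left_back)
# (2.0, 1.0) (2.0, -1.0) (-2.0, -1.0) (-2.0, 1.0)
```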
if __name__ == '__main__':
shift = 5 # motion token time dimension
num_cluster = 6 # vocabulary size (toy value; the released token file uses 2048)
cal_mean_heading = True
data = {
"veh": np.random.rand(1000, 6, 3),
"cyc": np.random.rand(1000, 6, 3),
"ped": np.random.rand(1000, 6, 3)
}
# Collect the trajectories of all traffic participants from the raw data [NumAgent, shift+1, [relative_x, relative_y, relative_theta]]
nms_res = {}
res = {'token': {}, 'traj': {}, 'token_all': {}}
for k, v in data.items():
# if k != 'veh':
# continue
a_pos = v
print(a_pos.shape)
# a_pos = a_pos[:, shift:1+shift, :]
cal_num = min(int(1e6), a_pos.shape[0])
a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)]
a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1])
print(a_pos.shape)
if shift <= 2:
if k == 'veh':
width = 1.0
length = 2.4
elif k == 'cyc':
width = 0.5
length = 1.5
else:
width = 0.5
length = 0.5
else:
if k == 'veh':
width = 2.0
length = 4.8
elif k == 'cyc':
width = 1.0
length = 2.0
else:
width = 1.0
length = 1.0
contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1], a_pos[:, shift, 2], width, length)
# plt.figure(figsize=(10, 10))
# for rect in contour:
# rect_closed = np.vstack([rect, rect[0]])
# plt.plot(rect_closed[:, 0], rect_closed[:, 1], linewidth=0.1)
# plt.title("Plot of 256 Rectangles")
# plt.xlabel("x")
# plt.ylabel("y")
# plt.axis('equal')
# plt.savefig(f'src_{k}_new.jpg', dpi=300)
if k == 'veh':
tol = 0.05
elif k == 'cyc':
tol = 0.004
else:
tol = 0.004
centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length, a_pos[:, :shift+1])
# plt.figure(figsize=(10, 10))
contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(num_cluster*(shift+1)),
ret_traj[:, :, 1].reshape(num_cluster*(shift+1)),
ret_traj[:, :, 2].reshape(num_cluster*(shift+1)), width, length)
res['token_all'][k] = contour.reshape(num_cluster, (shift+1), 4, 2)
res['token'][k] = centroids
res['traj'][k] = ret_traj
================================================
FILE: smart/__init__.py
================================================
================================================
FILE: smart/datamodules/__init__.py
================================================
from smart.datamodules.scalable_datamodule import MultiDataModule
================================================
FILE: smart/datamodules/scalable_datamodule.py
================================================
from typing import Optional
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
from smart.datasets.scalable_dataset import MultiDataset
from smart.transforms import WaymoTargetBuilder
class MultiDataModule(pl.LightningDataModule):
transforms = {
"WaymoTargetBuilder": WaymoTargetBuilder,
}
dataset = {
"scalable": MultiDataset,
}
def __init__(self,
root: str,
train_batch_size: int,
val_batch_size: int,
test_batch_size: int,
shuffle: bool = False,
num_workers: int = 0,
pin_memory: bool = True,
persistent_workers: bool = True,
train_raw_dir: Optional[str] = None,
val_raw_dir: Optional[str] = None,
test_raw_dir: Optional[str] = None,
train_processed_dir: Optional[str] = None,
val_processed_dir: Optional[str] = None,
test_processed_dir: Optional[str] = None,
transform: Optional[str] = None,
dataset: Optional[str] = None,
num_historical_steps: int = 50,
num_future_steps: int = 60,
processor='ntp',
use_intention=False,
token_size=512,
**kwargs) -> None:
super(MultiDataModule, self).__init__()
self.root = root
self.dataset_class = dataset
self.train_batch_size = train_batch_size
self.val_batch_size = val_batch_size
self.test_batch_size = test_batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers and num_workers > 0
self.train_raw_dir = train_raw_dir
self.val_raw_dir = val_raw_dir
self.test_raw_dir = test_raw_dir
self.train_processed_dir = train_processed_dir
self.val_processed_dir = val_processed_dir
self.test_processed_dir = test_processed_dir
self.processor = processor
self.use_intention = use_intention
self.token_size = token_size
train_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "train")
val_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "val")
test_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps)
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
def setup(self, stage: Optional[str] = None) -> None:
self.train_dataset = MultiDataModule.dataset[self.dataset_class](self.root, 'train', processed_dir=self.train_processed_dir,
raw_dir=self.train_raw_dir, processor=self.processor, transform=self.train_transform, token_size=self.token_size)
self.val_dataset = MultiDataModule.dataset[self.dataset_class](None, 'val', processed_dir=self.val_processed_dir,
raw_dir=self.val_raw_dir, processor=self.processor, transform=self.val_transform, token_size=self.token_size)
self.test_dataset = MultiDataModule.dataset[self.dataset_class](None, 'test', processed_dir=self.test_processed_dir,
raw_dir=self.test_raw_dir, processor=self.processor, transform=self.test_transform, token_size=self.token_size)
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
num_workers=self.num_workers, pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers)
def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False,
num_workers=self.num_workers, pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers)
def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False,
num_workers=self.num_workers, pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers)
================================================
FILE: smart/datasets/__init__.py
================================================
from smart.datasets.scalable_dataset import MultiDataset
================================================
FILE: smart/datasets/preprocess.py
================================================
import torch
import numpy as np
from scipy.interpolate import interp1d
from scipy.spatial.distance import euclidean
import math
import pickle
from smart.utils import wrap_angle
import os
def cal_polygon_contour(x, y, theta, width, length):
left_front_x = x + 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
left_front_y = y + 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
left_front = np.column_stack((left_front_x, left_front_y))
right_front_x = x + 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
right_front_y = y + 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
right_front = np.column_stack((right_front_x, right_front_y))
right_back_x = x - 0.5 * length * np.cos(theta) + 0.5 * width * np.sin(theta)
right_back_y = y - 0.5 * length * np.sin(theta) - 0.5 * width * np.cos(theta)
right_back = np.column_stack((right_back_x, right_back_y))
left_back_x = x - 0.5 * length * np.cos(theta) - 0.5 * width * np.sin(theta)
left_back_y = y - 0.5 * length * np.sin(theta) + 0.5 * width * np.cos(theta)
left_back = np.column_stack((left_back_x, left_back_y))
polygon_contour = np.concatenate(
(left_front[:, None, :], right_front[:, None, :], right_back[:, None, :], left_back[:, None, :]), axis=1)
return polygon_contour
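The corner layout above (left-front, right-front, right-back, left-back, with x pointing forward and y pointing left in the box frame) can be sketched in plain Python for a single box; `box_corners` is a hypothetical helper name, not part of this module:

```python
import math

def box_corners(x, y, theta, width, length):
    # Mirrors cal_polygon_contour above: corners ordered
    # left-front, right-front, right-back, left-back.
    c, s = math.cos(theta), math.sin(theta)
    hl, hw = 0.5 * length, 0.5 * width
    local = [(hl, hw), (hl, -hw), (-hl, -hw), (-hl, hw)]
    # Rotate each local corner offset by theta and translate to (x, y).
    return [(x + lx * c - ly * s, y + lx * s + ly * c) for lx, ly in local]

# A 4.8 m x 2.0 m vehicle at the origin, heading along +x:
corners = box_corners(0.0, 0.0, 0.0, width=2.0, length=4.8)
# corners[0] is the left-front corner at (2.4, 1.0).
```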
def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
# Compute the cumulative distance along the path and up-sample the polyline to a 0.5 m spacing
dist_along_path_list = [[0]]
polylines_list = [[polylines[0]]]
for i in range(1, polylines.shape[0]):
euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))
if heading_diff > math.pi / 4 and euclidean_dist > 3:
dist_along_path_list.append([0])
polylines_list.append([polylines[i]])
elif heading_diff > math.pi / 8 and euclidean_dist > 3:
dist_along_path_list.append([0])
polylines_list.append([polylines[i]])
elif heading_diff > 0.1 and euclidean_dist > 3:
dist_along_path_list.append([0])
polylines_list.append([polylines[i]])
elif euclidean_dist > 10:
dist_along_path_list.append([0])
polylines_list.append([polylines[i]])
else:
dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
polylines_list[-1].append(polylines[i])
# plt.plot(polylines[:, 0], polylines[:, 1])
# plt.savefig('tmp.jpg')
new_x_list = []
new_y_list = []
multi_polylines_list = []
for idx in range(len(dist_along_path_list)):
if len(dist_along_path_list[idx]) < 2:
continue
dist_along_path = np.array(dist_along_path_list[idx])
polylines_cur = np.array(polylines_list[idx])
# Create interpolation functions for x and y coordinates
fx = interp1d(dist_along_path, polylines_cur[:, 0])
fy = interp1d(dist_along_path, polylines_cur[:, 1])
# fyaw = interp1d(dist_along_path, heading)
# Create an array of distances at which to interpolate
new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
# Use the interpolation functions to generate new x and y coordinates
new_x = fx(new_dist_along_path)
new_y = fy(new_dist_along_path)
# new_yaw = fyaw(new_dist_along_path)
new_x_list.append(new_x)
new_y_list.append(new_y)
# Combine the new x and y coordinates into a single array
new_polylines = np.vstack((new_x, new_y)).T
polyline_size = int(split_distace / distance)
if new_polylines.shape[0] >= (polyline_size + 1):
padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
else:
padding_size = new_polylines.shape[0]
final_index = 0
multi_polylines = None
new_polylines = torch.from_numpy(new_polylines)
new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
new_polylines[1:, 0] - new_polylines[:-1, 0])
new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
new_polylines = torch.cat([new_polylines, new_heading], -1)
if new_polylines.shape[0] >= (polyline_size + 1):
multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
multi_polylines = multi_polylines.transpose(1, 2)
multi_polylines = multi_polylines[:, ::5, :]
if padding_size >= 3:
last_polyline = new_polylines[final_index * polyline_size:]
last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
if multi_polylines is not None:
multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
else:
multi_polylines = last_polyline.unsqueeze(0)
if multi_polylines is None:
continue
multi_polylines_list.append(multi_polylines)
if len(multi_polylines_list) > 0:
multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
else:
multi_polylines_list = None
return multi_polylines_list
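The core of the function above is arc-length resampling: map each vertex to its cumulative distance along the path, then linearly interpolate positions at fixed spacing (the role scipy's `interp1d` plays). A minimal pure-Python sketch of that step, with a hypothetical `resample` helper:

```python
import math

def resample(points, spacing):
    # Linear resampling at fixed arc-length spacing; the endpoint is
    # always kept, as with the np.concatenate of the last distance above.
    dists = [0.0]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        dists.append(dists[-1] + math.hypot(x1 - x0, y1 - y0))
    targets = [i * spacing for i in range(int(dists[-1] / spacing) + 1)]
    if targets[-1] < dists[-1]:
        targets.append(dists[-1])
    out, j = [], 0
    for d in targets:
        # Advance to the segment containing arc-length d, then lerp.
        while j + 1 < len(dists) - 1 and dists[j + 1] < d:
            j += 1
        seg = dists[j + 1] - dists[j]
        t = 0.0 if seg == 0 else (d - dists[j]) / seg
        (x0, y0), (x1, y1) = points[j], points[j + 1]
        out.append((x0 + t * (x1 - x0), y0 + t * (y1 - y0)))
    return out

pts = resample([(0.0, 0.0), (2.0, 0.0)], 0.5)   # x = 0.0, 0.5, ..., 2.0
```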
def average_distance_vectorized(point_set1, centroids):
dists = np.sqrt(np.sum((point_set1[:, None, :, :] - centroids[None, :, :, :]) ** 2, axis=-1))
return np.mean(dists, axis=2)
def assign_clusters(sub_X, centroids):
distances = average_distance_vectorized(sub_X, centroids)
return np.argmin(distances, axis=1)
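These two helpers implement nearest-centroid assignment under an average corner-distance metric. The same idea in plain Python, on toy contours (`avg_corner_dist` and `assign` are illustrative names only):

```python
import math

def avg_corner_dist(poly_a, poly_b):
    # Mean Euclidean distance between matching corners of two contours,
    # the metric average_distance_vectorized computes with numpy.
    return sum(math.dist(a, b) for a, b in zip(poly_a, poly_b)) / len(poly_a)

def assign(contour, centroids):
    # Index of the centroid contour closest to `contour` on average.
    dists = [avg_corner_dist(contour, c) for c in centroids]
    return dists.index(min(dists))

square = [(1, 1), (1, -1), (-1, -1), (-1, 1)]
shifted = [(x + 10, y) for x, y in square]   # same square, 10 m away
centroids = [shifted, square]
best = assign(square, centroids)             # picks the identical centroid
```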
class TokenProcessor:
def __init__(self, token_size):
module_dir = os.path.dirname(os.path.dirname(__file__))
self.agent_token_path = os.path.join(module_dir, f'tokens/cluster_frame_5_{token_size}.pkl')
self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
self.noise = False
self.disturb = False
self.shift = 5
self.get_trajectory_token()
self.training = False
self.current_step = 10
def preprocess(self, data):
data = self.tokenize_agent(data)
data = self.tokenize_map(data)
del data['city']
if 'polygon_is_intersection' in data['map_polygon']:
del data['map_polygon']['polygon_is_intersection']
if 'route_type' in data['map_polygon']:
del data['map_polygon']['route_type']
return data
def get_trajectory_token(self):
agent_token_data = pickle.load(open(self.agent_token_path, 'rb'))
map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
self.trajectory_token = agent_token_data['token']
self.trajectory_token_all = agent_token_data['token_all']
self.map_token = {'traj_src': map_token_traj['traj_src'], }
self.token_last = {}
for k, v in self.trajectory_token_all.items():
token_last = torch.from_numpy(v[:, -2:]).to(torch.float)
diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3]
theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
cos, sin = theta.cos(), theta.sin()
rot_mat = theta.new_zeros(token_last.shape[0], 2, 2)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = -sin
rot_mat[:, 1, 0] = sin
rot_mat[:, 1, 1] = cos
agent_token = torch.bmm(token_last[:, 1], rot_mat)
agent_token -= token_last[:, 0].mean(1)[:, None, :]
self.token_last[k] = agent_token.numpy()
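Both here and in `match_token` below, token templates are rotated by right-multiplying row-vector points with a 2x2 rotation matrix built from cos/sin of the agent heading. A small sketch of the `match_token` convention, where `[[cos, sin], [-sin, cos]]` on the right rotates row vectors by +theta (`rotate_rows` is a hypothetical helper):

```python
import math

def rotate_rows(points, theta):
    # Right-multiplying row vectors by [[cos, sin], [-sin, cos]] rotates
    # them by +theta, matching the rot_mat layout used in match_token.
    c, s = math.cos(theta), math.sin(theta)
    return [(x * c - y * s, x * s + y * c) for x, y in points]

# Rotating the point (1, 0) by 90 degrees lands on (0, 1):
pt, = rotate_rows([(1.0, 0.0)], math.pi / 2)
```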
def clean_heading(self, data):
heading = data['agent']['heading']
valid = data['agent']['valid_mask']
pi = torch.tensor(torch.pi)
n_vehicles, n_frames = heading.shape
heading_diff_raw = heading[:, :-1] - heading[:, 1:]
heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
heading_diff[heading_diff > pi] -= 2 * pi
heading_diff[heading_diff < -pi] += 2 * pi
valid_pairs = valid[:, :-1] & valid[:, 1:]
for i in range(n_frames - 1):
change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1]
heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()]
if i < n_frames - 2:
heading_diff_raw = heading[:, i + 1] - heading[:, i + 2]
heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi
heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi
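The normalization `clean_heading` applies before thresholding is the standard remainder trick for wrapping an angle difference into [-pi, pi). In isolation:

```python
import math

TWO_PI = 2 * math.pi

def wrap(d):
    # ((d + pi) mod 2*pi) - pi maps any angle difference into [-pi, pi),
    # mirroring the torch.remainder normalization in clean_heading.
    return (d + math.pi) % TWO_PI - math.pi

# A 350-degree "jump" between frames is really a -10-degree change:
d = wrap(math.radians(350))
```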
def tokenize_agent(self, data):
if data['agent']["velocity"].shape[1] == 90:
print(data['scenario_id'], data['agent']["velocity"].shape)
interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * (
data['agent']['position'][:, self.current_step, 0] != 0)
if data['agent']["velocity"].shape[-1] == 2:
data['agent']["velocity"] = torch.cat([data['agent']["velocity"],
torch.zeros(data['agent']["velocity"].shape[0],
data['agent']["velocity"].shape[1], 1)], dim=-1)
vel = data['agent']["velocity"][interplote_mask, self.current_step]
data['agent']['position'][interplote_mask, self.current_step - 1, :3] = data['agent']['position'][
interplote_mask, self.current_step,
:3] - vel * 0.1
data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True
data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][
interplote_mask, self.current_step]
data['agent']["velocity"][interplote_mask, self.current_step - 1] = data['agent']["velocity"][
interplote_mask, self.current_step]
data['agent']['type'] = data['agent']['type'].to(torch.uint8)
self.clean_heading(data)
matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * (
data['agent']['valid_mask'][:, self.current_step - 5] == False)
interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0)
data['agent']['valid_mask'][interplote_mask_first, 0] = True
agent_pos = data['agent']['position'][:, :, :2]
valid_mask = data['agent']['valid_mask']
valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)
token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]
agent_type = data['agent']['type']
agent_category = data['agent']['category']
agent_heading = data['agent']['heading']
vehicle_mask = agent_type == 0
cyclist_mask = agent_type == 2
ped_mask = agent_type == 1
veh_pos = agent_pos[vehicle_mask, :, :]
veh_valid_mask = valid_mask[vehicle_mask, :]
cyc_pos = agent_pos[cyclist_mask, :, :]
cyc_valid_mask = valid_mask[cyclist_mask, :]
ped_pos = agent_pos[ped_mask, :, :]
ped_valid_mask = valid_mask[ped_mask, :]
veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask],
'veh', agent_category[vehicle_mask],
matching_extra_mask[vehicle_mask])
ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped',
agent_category[ped_mask], matching_extra_mask[ped_mask])
cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask],
'cyc', agent_category[cyclist_mask],
matching_extra_mask[cyclist_mask])
token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)
token_index[vehicle_mask] = veh_token_index
token_index[ped_mask] = ped_token_index
token_index[cyclist_mask] = cyc_token_index
token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],
veh_token_contour.shape[2], veh_token_contour.shape[3]))
token_contour[vehicle_mask] = veh_token_contour
token_contour[ped_mask] = ped_token_contour
token_contour[cyclist_mask] = cyc_token_contour
trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float)
trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float)
trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float)
agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2))
agent_token_traj[vehicle_mask] = trajectory_token_veh
agent_token_traj[ped_mask] = trajectory_token_ped
agent_token_traj[cyclist_mask] = trajectory_token_cyc
if not self.training:
token_valid_mask[matching_extra_mask, 1] = True
data['agent']['token_idx'] = token_index
data['agent']['token_contour'] = token_contour
token_pos = token_contour.mean(dim=2)
data['agent']['token_pos'] = token_pos
diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]
data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
data['agent']['agent_valid_mask'] = token_valid_mask
vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2),
((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1)
vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool),
(token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1)
vel[~vel_valid_mask] = 0
vel[data['agent']['valid_mask'][:, self.current_step], 1] = data['agent']['velocity'][
data['agent']['valid_mask'][:, self.current_step],
self.current_step, :2]
data['agent']['token_velocity'] = vel
return data
def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask):
agent_token_src = self.trajectory_token[category]
token_last = self.token_last[category]
if self.shift <= 2:
if category == 'veh':
width = 1.0
length = 2.4
elif category == 'cyc':
width = 0.5
length = 1.5
else:
width = 0.5
length = 0.5
else:
if category == 'veh':
width = 2.0
length = 4.8
elif category == 'cyc':
width = 1.0
length = 2.0
else:
width = 1.0
length = 1.0
prev_heading = heading[:, 0]
prev_pos = pos[:, 0]
agent_num, num_step, feat_dim = pos.shape
token_num, token_contour_dim, feat_dim = agent_token_src.shape
agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)
token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0)
token_index_list = []
token_contour_list = []
prev_token_idx = None
for i in range(self.shift, pos.shape[1], self.shift):
theta = prev_heading
cur_heading = heading[:, i]
cur_pos = pos[:, i]
cos, sin = theta.cos(), theta.sin()
rot_mat = theta.new_zeros(agent_num, 2, 2)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(agent_num,
token_num,
token_contour_dim,
feat_dim)
agent_token_world += prev_pos[:, None, None, :]
cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
agent_token_index = torch.from_numpy(np.argmin(
np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
axis=-1))
if prev_token_idx is not None and self.noise:
same_idx = prev_token_idx == agent_token_index
same_idx[:] = True
topk_indices = np.argsort(
np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
axis=2), axis=-1)[:, :5]
sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
agent_token_index[same_idx] = \
torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]
token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index]
diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :]
prev_heading = heading[:, i].clone()
prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
valid_mask[:, i - self.shift]]
prev_pos = pos[:, i].clone()
prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]]
prev_token_idx = agent_token_index
token_index_list.append(agent_token_index[:, None])
token_contour_list.append(token_contour_select[:, None, ...])
token_index = torch.cat(token_index_list, dim=1)
token_contour = torch.cat(token_contour_list, dim=1)
# extra matching
if not self.training:
theta = heading[extra_mask, self.current_step - 1]
prev_pos = pos[extra_mask, self.current_step - 1]
cur_pos = pos[extra_mask, self.current_step]
cur_heading = heading[extra_mask, self.current_step]
cos, sin = theta.cos(), theta.sin()
rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(
extra_mask.sum(), token_num, token_contour_dim, feat_dim)
agent_token_world += prev_pos[:, None, None, :]
cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
agent_token_index = torch.from_numpy(np.argmin(
np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
axis=-1))
token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]
token_index[extra_mask, 1] = agent_token_index
token_contour[extra_mask, 1] = token_contour_select
return token_index, token_contour
def tokenize_map(self, data):
data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
pt_type = data['map_point']['type'].to(torch.uint8)
pt_side = torch.zeros_like(pt_type)
pt_pos = data['map_point']['position'][:, :2]
data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
pt_heading = data['map_point']['orientation']
split_polyline_type = []
split_polyline_pos = []
split_polyline_theta = []
split_polyline_side = []
pl_idx_list = []
split_polygon_type = []
data['map_point']['type'].unique()
for i in sorted(np.unique(pt2pl[1])):
index = pt2pl[0, pt2pl[1] == i]
polygon_type = data['map_polygon']["type"][i]
cur_side = pt_side[index]
cur_type = pt_type[index]
cur_pos = pt_pos[index]
cur_heading = pt_heading[index]
for side_val in np.unique(cur_side):
for type_val in np.unique(cur_type):
if type_val == 13:
continue
indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
if len(indices) <= 2:
continue
split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
if split_polyline is None:
continue
new_cur_type = cur_type[indices][0]
new_cur_side = cur_side[indices][0]
map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
cur_pl_idx = torch.Tensor([i])
new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
split_polyline_pos.append(split_polyline[..., :2])
split_polyline_theta.append(split_polyline[..., 2])
split_polyline_type.append(new_cur_type)
split_polyline_side.append(new_cur_side)
pl_idx_list.append(new_cur_pl_idx)
split_polygon_type.append(map_polygon_type)
split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
split_polyline_type = torch.cat(split_polyline_type, dim=0)
split_polyline_side = torch.cat(split_polyline_side, dim=0)
split_polygon_type = torch.cat(split_polygon_type, dim=0)
pl_idx_list = torch.cat(pl_idx_list, dim=0)
vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :]
data['map_save'] = {}
data['pt_token'] = {}
data['map_save']['traj_pos'] = split_polyline_pos
data['map_save']['traj_theta'] = split_polyline_theta[:, 0] # torch.arctan2(vec[:, 1], vec[:, 0])
data['map_save']['pl_idx_list'] = pl_idx_list
data['pt_token']['type'] = split_polyline_type
data['pt_token']['side'] = split_polyline_side
data['pt_token']['pl_type'] = split_polygon_type
data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
return data
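`tokenize_map` iterates over polygons by grouping map points through the (point → polygon) `edge_index`, where row 0 holds point ids and row 1 their parent polygon ids. That grouping, in plain Python with toy ids:

```python
from collections import defaultdict

# edge_index[0][k] is a point id, edge_index[1][k] its parent polygon id,
# the layout of data[('map_point', 'to', 'map_polygon')]['edge_index'].
edge_index = [[0, 1, 2, 3, 4, 5],
              [0, 0, 0, 1, 1, 1]]

groups = defaultdict(list)
for pt, pl in zip(*edge_index):
    groups[pl].append(pt)
# groups maps polygon id -> its point ids, ready for per-polygon splitting.
```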
================================================
FILE: smart/datasets/scalable_dataset.py
================================================
import os
import pickle
from typing import Callable, List, Optional, Tuple, Union
import pandas as pd
from torch_geometric.data import Dataset
from smart.utils.log import Logging
import numpy as np
from .preprocess import TokenProcessor
def distance(point1, point2):
return np.sqrt((point2[0] - point1[0])**2 + (point2[1] - point1[1])**2)
class MultiDataset(Dataset):
def __init__(self,
root: str,
split: str,
raw_dir: List[str] = None,
processed_dir: List[str] = None,
transform: Optional[Callable] = None,
dim: int = 3,
num_historical_steps: int = 50,
num_future_steps: int = 60,
predict_unseen_agents: bool = False,
vector_repr: bool = True,
cluster: bool = False,
processor=None,
use_intention=False,
token_size=512) -> None:
self.logger = Logging().log(level='DEBUG')
self.root = root
self.well_done = [0]
if split not in ('train', 'val', 'test'):
raise ValueError(f'{split} is not a valid split')
self.split = split
self.training = split == 'train'
self.logger.debug("Starting loading dataset")
self._raw_file_names = []
self._raw_paths = []
self._raw_file_dataset = []
if raw_dir is not None:
self._raw_dir = raw_dir
for raw_dir in self._raw_dir:
raw_dir = os.path.expanduser(os.path.normpath(raw_dir))
dataset = "waymo"
file_list = os.listdir(raw_dir)
self._raw_file_names.extend(file_list)
self._raw_paths.extend([os.path.join(raw_dir, f) for f in file_list])
self._raw_file_dataset.extend([dataset for _ in range(len(file_list))])
if self.root is not None:
split_datainfo = os.path.join(root, "split_datainfo.pkl")
with open(split_datainfo, 'rb+') as f:
split_datainfo = pickle.load(f)
if split == "test":
split = "val"
self._processed_file_names = split_datainfo[split]
self.dim = dim
self.num_historical_steps = num_historical_steps
self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names)
self.logger.debug("The number of {} dataset is ".format(split) + str(self._num_samples))
self.token_processor = TokenProcessor(token_size)
super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)
@property
def raw_dir(self) -> str:
return self._raw_dir
@property
def raw_paths(self) -> List[str]:
return self._raw_paths
@property
def raw_file_names(self) -> Union[str, List[str], Tuple]:
return self._raw_file_names
@property
def processed_file_names(self) -> Union[str, List[str], Tuple]:
return self._processed_file_names
def len(self) -> int:
return self._num_samples
def generate_ref_token(self):
pass
def get(self, idx: int):
with open(self.raw_paths[idx], 'rb') as handle:
data = pickle.load(handle)
data = self.token_processor.preprocess(data)
return data
================================================
FILE: smart/layers/__init__.py
================================================
from smart.layers.attention_layer import AttentionLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from smart.layers.mlp_layer import MLPLayer
================================================
FILE: smart/layers/attention_layer.py
================================================
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
from smart.utils import weight_init
class AttentionLayer(MessagePassing):
def __init__(self,
hidden_dim: int,
num_heads: int,
head_dim: int,
dropout: float,
bipartite: bool,
has_pos_emb: bool,
**kwargs) -> None:
super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
self.num_heads = num_heads
self.head_dim = head_dim
self.has_pos_emb = has_pos_emb
self.scale = head_dim ** -0.5
self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
if has_pos_emb:
self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
self.attn_drop = nn.Dropout(dropout)
self.ff_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 4, hidden_dim),
)
if bipartite:
self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
else:
self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
self.attn_prenorm_x_dst = self.attn_prenorm_x_src
if has_pos_emb:
self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
self.attn_postnorm = nn.LayerNorm(hidden_dim)
self.ff_prenorm = nn.LayerNorm(hidden_dim)
self.ff_postnorm = nn.LayerNorm(hidden_dim)
self.apply(weight_init)
def forward(self,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
r: Optional[torch.Tensor],
edge_index: torch.Tensor) -> torch.Tensor:
if isinstance(x, torch.Tensor):
x_src = x_dst = self.attn_prenorm_x_src(x)
else:
x_src, x_dst = x
x_src = self.attn_prenorm_x_src(x_src)
x_dst = self.attn_prenorm_x_dst(x_dst)
x = x[1]
if self.has_pos_emb and r is not None:
r = self.attn_prenorm_r(r)
x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
return x
def message(self,
q_i: torch.Tensor,
k_j: torch.Tensor,
v_j: torch.Tensor,
r: Optional[torch.Tensor],
index: torch.Tensor,
ptr: Optional[torch.Tensor]) -> torch.Tensor:
if self.has_pos_emb and r is not None:
k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
sim = (q_i * k_j).sum(dim=-1) * self.scale
attn = softmax(sim, index, ptr)
self.attention_weight = attn.sum(-1).detach()
attn = self.attn_drop(attn)
return v_j * attn.unsqueeze(-1)
def update(self,
inputs: torch.Tensor,
x_dst: torch.Tensor) -> torch.Tensor:
inputs = inputs.view(-1, self.num_heads * self.head_dim)
g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
return inputs + g * (self.to_s(x_dst) - inputs)
def _attn_block(self,
x_src: torch.Tensor,
x_dst: torch.Tensor,
r: Optional[torch.Tensor],
edge_index: torch.Tensor) -> torch.Tensor:
q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
return self.to_out(agg)
def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
return self.ff_mlp(x)
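The `update` step above blends aggregated messages with a transformed destination feature through a sigmoid gate: out = m + g * (s - m), so g = 0 keeps the attention output and g = 1 replaces it with the destination-side term. Element-wise, with toy numbers and no learned weights:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gated_update(msg, skip, gate_logit):
    # out = msg + g * (skip - msg), the convex blend computed in
    # AttentionLayer.update (there g comes from to_g, skip from to_s).
    g = sigmoid(gate_logit)
    return [m + g * (s - m) for m, s in zip(msg, skip)]

out = gated_update([0.0, 2.0], [4.0, 2.0], gate_logit=0.0)   # g = 0.5
```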
================================================
FILE: smart/layers/fourier_embedding.py
================================================
import math
from typing import List, Optional
import torch
import torch.nn as nn
from smart.utils import weight_init
class FourierEmbedding(nn.Module):
def __init__(self,
input_dim: int,
hidden_dim: int,
num_freq_bands: int) -> None:
super(FourierEmbedding, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
self.mlps = nn.ModuleList(
[nn.Sequential(
nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim),
)
for _ in range(input_dim)])
self.to_out = nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim),
)
self.apply(weight_init)
def forward(self,
continuous_inputs: Optional[torch.Tensor] = None,
categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
if continuous_inputs is None:
if categorical_embs is not None:
x = torch.stack(categorical_embs).sum(dim=0)
else:
raise ValueError('Both continuous_inputs and categorical_embs are None')
else:
x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
# Warning: if your data are noisy, don't use learnable sinusoidal embedding
x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim
for i in range(self.input_dim):
continuous_embs[i] = self.mlps[i](x[:, i])
x = torch.stack(continuous_embs).sum(dim=0)
if categorical_embs is not None:
x = x + torch.stack(categorical_embs).sum(dim=0)
return self.to_out(x)
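Before the per-dimension MLPs, `FourierEmbedding.forward` expands each scalar input into `num_freq_bands * 2 + 1` features: cosines and sines at the learned frequencies plus the raw value. The feature construction alone, with hypothetical fixed frequencies:

```python
import math

def fourier_features(x, freqs):
    # Per scalar input: cos(2*pi*f*x) and sin(2*pi*f*x) for each frequency,
    # plus the raw x, as concatenated in FourierEmbedding.forward.
    angles = [2 * math.pi * f * x for f in freqs]
    return [math.cos(a) for a in angles] + [math.sin(a) for a in angles] + [x]

feats = fourier_features(0.25, freqs=[1.0, 2.0])   # 2 bands -> 5 features
```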
class MLPEmbedding(nn.Module):
def __init__(self,
input_dim: int,
hidden_dim: int) -> None:
super(MLPEmbedding, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.mlp = nn.Sequential(
nn.Linear(input_dim, 128),
nn.LayerNorm(128),
nn.ReLU(inplace=True),
nn.Linear(128, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, hidden_dim))
self.apply(weight_init)
def forward(self,
continuous_inputs: Optional[torch.Tensor] = None,
categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
if continuous_inputs is None:
if categorical_embs is not None:
x = torch.stack(categorical_embs).sum(dim=0)
else:
raise ValueError('Both continuous_inputs and categorical_embs are None')
else:
x = self.mlp(continuous_inputs)
if categorical_embs is not None:
x = x + torch.stack(categorical_embs).sum(dim=0)
return x
================================================
FILE: smart/layers/mlp_layer.py
================================================
import torch
import torch.nn as nn
from smart.utils import weight_init
class MLPLayer(nn.Module):
def __init__(self,
input_dim: int,
hidden_dim: int,
output_dim: int) -> None:
super(MLPLayer, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, output_dim),
)
self.apply(weight_init)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
================================================
FILE: smart/metrics/__init__.py
================================================
from smart.metrics.average_meter import AverageMeter
from smart.metrics.min_ade import minADE
from smart.metrics.min_fde import minFDE
from smart.metrics.next_token_cls import TokenCls
================================================
FILE: smart/metrics/average_meter.py
================================================
import torch
from torchmetrics import Metric
class AverageMeter(Metric):
def __init__(self, **kwargs) -> None:
super(AverageMeter, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
def update(self, val: torch.Tensor) -> None:
self.sum += val.sum()
self.count += val.numel()
def compute(self) -> torch.Tensor:
return self.sum / self.count
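`AverageMeter` stores a sum and a count rather than a running mean so that `dist_reduce_fx='sum'` recovers the exact global mean across processes. The invariant, with two simulated partial states:

```python
# Two "processes" each hold partial (sum, count) state; summing the
# states and then dividing reproduces the mean over all values exactly,
# which a per-process mean of means would not.
vals_a, vals_b = [1.0, 2.0, 3.0], [10.0]
state_a = (sum(vals_a), len(vals_a))
state_b = (sum(vals_b), len(vals_b))
total_sum = state_a[0] + state_b[0]
total_cnt = state_a[1] + state_b[1]
mean = total_sum / total_cnt   # (6 + 10) / 4 = 4.0
```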
================================================
FILE: smart/metrics/min_ade.py
================================================
from typing import Optional
import torch
from torchmetrics import Metric
from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter
class minMultiADE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minMultiADE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True,
min_criterion: str = 'FDE') -> None:
pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
pred_topk, _ = topk(self.max_guesses, pred, prob)
if min_criterion == 'FDE':
inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
inds_best = torch.norm(
pred_topk[torch.arange(pred.size(0)), :, inds_last] -
target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
elif min_criterion == 'ADE':
self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
else:
raise ValueError('{} is not a valid criterion'.format(min_criterion))
self.count += pred.size(0)
def compute(self) -> torch.Tensor:
return self.sum / self.count
class minADE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minADE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
self.eval_timestep = 70
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True,
min_criterion: str = 'ADE') -> None:
eval_timestep = min(self.eval_timestep, pred.shape[1])
valid = valid_mask[:, :eval_timestep]
self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid).sum(dim=-1) / valid.sum(dim=-1).clamp(min=1)).sum()
self.count += valid.any(dim=-1).sum()
def compute(self) -> torch.Tensor:
return self.sum / self.count
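`minADE.update` above reduces to a per-agent mean of Euclidean displacements over the valid steps of the evaluation window. A dependency-free sketch of that reduction (hypothetical helper `ade`, plain Python instead of torch; it normalizes by the number of valid steps):

```python
import math

def ade(pred, target, valid_mask, eval_timestep=70):
    """Average displacement error over the first eval_timestep steps.

    pred, target: lists of (x, y) tuples; valid_mask: list of bools.
    Invalid steps contribute nothing to the sum or the count.
    """
    horizon = min(eval_timestep, len(pred))
    total, count = 0.0, 0
    for p, t, v in zip(pred[:horizon], target[:horizon], valid_mask[:horizon]):
        if v:
            total += math.hypot(p[0] - t[0], p[1] - t[1])
            count += 1
    return total / count if count else 0.0
```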
================================================
FILE: smart/metrics/min_fde.py
================================================
from typing import Optional
import torch
from torchmetrics import Metric
from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter
class minMultiFDE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minMultiFDE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True) -> None:
pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
pred_topk, _ = topk(self.max_guesses, pred, prob)
inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
p=2, dim=-1).min(dim=-1)[0].sum()
self.count += pred.size(0)
def compute(self) -> torch.Tensor:
return self.sum / self.count
class minFDE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minFDE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
self.eval_timestep = 70
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True) -> None:
eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
self.count += valid_mask[:, eval_timestep-1].sum()
def compute(self) -> torch.Tensor:
return self.sum / self.count
================================================
FILE: smart/metrics/next_token_cls.py
================================================
from typing import Optional
import torch
from torchmetrics import Metric
from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter
class TokenCls(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(TokenCls, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None) -> None:
target = target[..., None]
acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask
self.sum += acc.sum()
self.count += valid_mask.sum()
def compute(self) -> torch.Tensor:
return self.sum / self.count
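`TokenCls` counts a position as correct when the ground-truth token id appears among the first `max_guesses` predictions. The same top-k accuracy, sketched without torch (hypothetical function name):

```python
def topk_token_accuracy(pred_tokens, targets, valid, k=6):
    """Fraction of valid positions whose target is within the first k guesses."""
    hits, total = 0, 0
    for guesses, tgt, v in zip(pred_tokens, targets, valid):
        if not v:
            continue  # invalid positions are excluded from both counts
        total += 1
        if tgt in guesses[:k]:
            hits += 1
    return hits / total if total else 0.0
```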
================================================
FILE: smart/metrics/utils.py
================================================
from typing import Optional, Tuple
import torch
from torch_scatter import gather_csr
from torch_scatter import segment_csr
def topk(
max_guesses: int,
pred: torch.Tensor,
prob: Optional[torch.Tensor] = None,
ptr: Optional[torch.Tensor] = None,
joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
max_guesses = min(max_guesses, pred.size(1))
if max_guesses == pred.size(1):
if prob is not None:
prob = prob / prob.sum(dim=-1, keepdim=True)
else:
prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred, prob
else:
if prob is not None:
if joint:
if ptr is None:
inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = inds_topk.repeat(pred.size(0), 1)
else:
inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
reduce='mean'),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = gather_csr(src=inds_topk, indptr=ptr)
else:
inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
else:
pred_topk = pred[:, :max_guesses]
prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred_topk, prob_topk
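When fewer than all modes are requested, `topk` keeps the `max_guesses` highest-probability modes and renormalizes their probabilities. The per-agent selection logic can be sketched in plain Python (hypothetical helper; the CSR/joint handling is omitted):

```python
def topk_modes(preds, probs, k):
    """Keep the k highest-probability modes; renormalize the kept probabilities."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    sel_preds = [preds[i] for i in order]
    sel_probs = [probs[i] for i in order]
    total = sum(sel_probs)
    return sel_preds, [p / total for p in sel_probs]
```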
def topkind(
max_guesses: int,
pred: torch.Tensor,
prob: Optional[torch.Tensor] = None,
ptr: Optional[torch.Tensor] = None,
joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
max_guesses = min(max_guesses, pred.size(1))
if max_guesses == pred.size(1):
if prob is not None:
prob = prob / prob.sum(dim=-1, keepdim=True)
else:
prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred, prob, None
else:
if prob is not None:
if joint:
if ptr is None:
inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = inds_topk.repeat(pred.size(0), 1)
else:
inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
reduce='mean'),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = gather_csr(src=inds_topk, indptr=ptr)
else:
inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
else:
pred_topk = pred[:, :max_guesses]
prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred_topk, prob_topk, inds_topk
def valid_filter(
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
ptr: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
torch.Tensor, torch.Tensor]:
if valid_mask is None:
valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
if keep_invalid_final_step:
filter_mask = valid_mask.any(dim=-1)
else:
filter_mask = valid_mask[:, -1]
pred = pred[filter_mask]
target = target[filter_mask]
if prob is not None:
prob = prob[filter_mask]
valid_mask = valid_mask[filter_mask]
if ptr is not None:
num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
else:
ptr = target.new_tensor([0, target.size(0)])
return pred, target, prob, valid_mask, ptr
def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
"""
Args:
pred_trajs (batch_size, num_modes, num_timestamps, 7)
pred_scores (batch_size, num_modes):
dist_thresh (float):
num_ret_modes (int, optional): Defaults to 6.
Returns:
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
ret_scores (batch_size, num_ret_modes)
ret_idxs (batch_size, num_ret_modes)
"""
batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
pred_goals = pred_trajs[:, :, -1, :]
dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
nearby_neighbor = dist < dist_thresh
pred_scores = nearby_neighbor.sum(dim=-1) / num_modes
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
point_cover_mask = (dist < dist_thresh)
point_val = sorted_pred_scores.clone() # (batch_size, N)
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
for k in range(num_ret_modes):
cur_idx = point_val.argmax(dim=-1) # (batch_size)
ret_idxs[:, k] = cur_idx
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
point_val_selected[bs_idxs, cur_idx] = -1
point_val += point_val_selected
ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
return ret_trajs, ret_scores, ret_idxs
def batch_nms(pred_trajs, pred_scores,
dist_thresh, num_ret_modes=6,
mode='static', speed=None):
"""
Args:
pred_trajs (batch_size, num_modes, num_timestamps, 7)
pred_scores (batch_size, num_modes):
dist_thresh (float):
num_ret_modes (int, optional): Defaults to 6.
Returns:
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
ret_scores (batch_size, num_ret_modes)
ret_idxs (batch_size, num_ret_modes)
"""
batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
if mode == "speed":
scale = torch.ones(batch_size).to(sorted_pred_goals.device)
lon_dist_thresh = 4 * scale
lat_dist_thresh = 0.5 * scale
lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
else:
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
point_cover_mask = (dist < dist_thresh)
point_val = sorted_pred_scores.clone() # (batch_size, N)
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
for k in range(num_ret_modes):
cur_idx = point_val.argmax(dim=-1) # (batch_size)
ret_idxs[:, k] = cur_idx
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
point_val_selected[bs_idxs, cur_idx] = -1
point_val += point_val_selected
ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
return ret_trajs, ret_scores, ret_idxs
def batch_nms_token(pred_trajs, pred_scores,
dist_thresh, num_ret_modes=6,
mode='static', speed=None):
"""
Args:
pred_trajs (batch_size, num_modes, num_timestamps, 7)
pred_scores (batch_size, num_modes):
dist_thresh (float):
num_ret_modes (int, optional): Defaults to 6.
Returns:
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
ret_scores (batch_size, num_ret_modes)
ret_idxs (batch_size, num_ret_modes)
"""
batch_size, num_modes, num_feat_dim = pred_trajs.shape
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_feat_dim)
if mode == "nearby":
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
values, _ = torch.topk(dist, 5, dim=-1, largest=False) # distance to the 5th-nearest goal
threshold = values[..., -1]
point_cover_mask = dist < threshold[..., None]
else:
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
point_cover_mask = (dist < dist_thresh)
point_val = sorted_pred_scores.clone() # (batch_size, N)
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
for k in range(num_ret_modes):
cur_idx = point_val.argmax(dim=-1) # (batch_size)
ret_idxs[:, k] = cur_idx
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
point_val_selected[bs_idxs, cur_idx] = -1
point_val += point_val_selected
ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
return ret_goals, ret_scores, ret_idxs
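All three NMS variants above share one greedy core: pick the highest-scoring remaining mode, then zero the scores of every mode whose goal it covers; the picked index itself is driven negative so it cannot be re-selected. A single-sample sketch in plain Python (hypothetical helper, 2-D goals only):

```python
import math

def greedy_nms(goals, scores, dist_thresh, num_ret_modes=6):
    """Return up to num_ret_modes goal indices, suppressing nearby duplicates."""
    val = list(scores)
    keep = []
    for _ in range(min(num_ret_modes, len(goals))):
        cur = max(range(len(val)), key=lambda i: val[i])  # best remaining score
        keep.append(cur)
        gx, gy = goals[cur]
        for i, (x, y) in enumerate(goals):
            if math.hypot(x - gx, y - gy) < dist_thresh:
                val[i] = 0.0  # covered neighbours drop out of contention
        val[cur] = -1.0  # the pick itself can never win again
    return keep
```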
================================================
FILE: smart/model/__init__.py
================================================
from smart.model.smart import SMART
================================================
FILE: smart/model/smart.py
================================================
import contextlib
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from smart.metrics import minADE
from smart.metrics import minFDE
from smart.metrics import TokenCls
from smart.modules import SMARTDecoder
from torch.optim.lr_scheduler import LambdaLR
import math
import numpy as np
import pickle
from collections import defaultdict
import os
from waymo_open_dataset.protos import sim_agents_submission_pb2
def cal_polygon_contour(x, y, theta, width, length):
left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
left_front = (left_front_x, left_front_y)
right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
right_front = (right_front_x, right_front_y)
right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
right_back = (right_back_x, right_back_y)
left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
left_back = (left_back_x, left_back_y)
polygon_contour = [left_front, right_front, right_back, left_back]
return polygon_contour
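`cal_polygon_contour` returns the four corners of an oriented bounding box in the order left-front, right-front, right-back, left-back. With heading zero the rotation terms vanish, which gives a quick sanity check (a standalone restatement of the same math, not an import from the module):

```python
import math

def box_corners(x, y, theta, width, length):
    # Same corner order as cal_polygon_contour:
    # [left_front, right_front, right_back, left_back]
    c, s = math.cos(theta), math.sin(theta)
    hl, hw = 0.5 * length, 0.5 * width
    return [
        (x + hl * c - hw * s, y + hl * s + hw * c),  # left front
        (x + hl * c + hw * s, y + hl * s - hw * c),  # right front
        (x - hl * c + hw * s, y - hl * s - hw * c),  # right back
        (x - hl * c - hw * s, y - hl * s + hw * c),  # left back
    ]
```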
def joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene:
states = states.numpy()
simulated_trajectories = []
for i_object in range(len(object_ids)):
simulated_trajectories.append(sim_agents_submission_pb2.SimulatedTrajectory(
center_x=states[i_object, :, 0], center_y=states[i_object, :, 1],
center_z=states[i_object, :, 2], heading=states[i_object, :, 3],
object_id=object_ids[i_object].item()
))
return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories)
class SMART(pl.LightningModule):
def __init__(self, model_config) -> None:
super(SMART, self).__init__()
self.save_hyperparameters()
self.model_config = model_config
self.warmup_steps = model_config.warmup_steps
self.lr = model_config.lr
self.total_steps = model_config.total_steps
self.dataset = model_config.dataset
self.input_dim = model_config.input_dim
self.hidden_dim = model_config.hidden_dim
self.output_dim = model_config.output_dim
self.output_head = model_config.output_head
self.num_historical_steps = model_config.num_historical_steps
self.num_future_steps = model_config.decoder.num_future_steps
self.num_freq_bands = model_config.num_freq_bands
self.vis_map = False
self.noise = True
module_dir = os.path.dirname(os.path.dirname(__file__))
self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
self.init_map_token()
self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl')
token_data = self.get_trajectory_token()
self.encoder = SMARTDecoder(
dataset=model_config.dataset,
input_dim=model_config.input_dim,
hidden_dim=model_config.hidden_dim,
num_historical_steps=model_config.num_historical_steps,
num_freq_bands=model_config.num_freq_bands,
num_heads=model_config.num_heads,
head_dim=model_config.head_dim,
dropout=model_config.dropout,
num_map_layers=model_config.decoder.num_map_layers,
num_agent_layers=model_config.decoder.num_agent_layers,
pl2pl_radius=model_config.decoder.pl2pl_radius,
pl2a_radius=model_config.decoder.pl2a_radius,
a2a_radius=model_config.decoder.a2a_radius,
time_span=model_config.decoder.time_span,
map_token={'traj_src': self.map_token['traj_src']},
token_data=token_data,
token_size=model_config.decoder.token_size
)
self.minADE = minADE(max_guesses=1)
self.minFDE = minFDE(max_guesses=1)
self.TokenCls = TokenCls(max_guesses=1)
self.test_predictions = dict()
self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
self.inference_token = False
self.rollout_num = 1
def get_trajectory_token(self):
token_data = pickle.load(open(self.token_path, 'rb'))
self.trajectory_token = token_data['token']
self.trajectory_token_traj = token_data['traj']
self.trajectory_token_all = token_data['token_all']
return token_data
def init_map_token(self):
self.argmin_sample_len = 3
map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
self.map_token = {'traj_src': map_token_traj['traj_src'], }
traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],
self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])
indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()
self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)
self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)
def forward(self, data: HeteroData):
res = self.encoder(data)
return res
def inference(self, data: HeteroData):
res = self.encoder.inference(data)
return res
def maybe_autocast(self, dtype=torch.float16):
enable_autocast = self.device != torch.device("cpu")
if enable_autocast:
return torch.cuda.amp.autocast(dtype=dtype)
else:
return contextlib.nullcontext()
def training_step(self,
data,
batch_idx):
data = self.match_token_map(data)
data = self.sample_pt_pred(data)
if isinstance(data, Batch):
data['agent']['av_index'] += data['agent']['ptr'][:-1]
pred = self(data)
next_token_prob = pred['next_token_prob']
next_token_idx_gt = pred['next_token_idx_gt']
next_token_eval_mask = pred['next_token_eval_mask']
cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
loss = cls_loss
self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
return loss
def validation_step(self,
data,
batch_idx):
data = self.match_token_map(data)
data = self.sample_pt_pred(data)
if isinstance(data, Batch):
data['agent']['av_index'] += data['agent']['ptr'][:-1]
pred = self(data)
next_token_idx = pred['next_token_idx']
next_token_idx_gt = pred['next_token_idx_gt']
next_token_eval_mask = pred['next_token_eval_mask']
next_token_prob = pred['next_token_prob']
cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
loss = cls_loss
self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
valid_mask=next_token_eval_mask[next_token_eval_mask])
self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1] # * (data['agent']['category'] == 3)
if self.inference_token:
pred = self.inference(data)
pos_a = pred['pos_a']
gt = pred['gt']
valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:]
pred_traj = pred['pred_traj']
eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]
self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)
self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)
def on_validation_start(self):
self.gt = []
self.pred = []
self.scenario_rollouts = []
self.batch_metric = defaultdict(list)
def configure_optimizers(self):
optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
def lr_lambda(current_step):
if current_step + 1 < self.warmup_steps:
return float(current_step + 1) / float(max(1, self.warmup_steps))
return max(
0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))
)
lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
return [optimizer], [lr_scheduler]
def load_params_from_file(self, filename, logger, to_cpu=False):
if not os.path.isfile(filename):
raise FileNotFoundError(filename)
logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
loc_type = torch.device('cpu') if to_cpu else None
checkpoint = torch.load(filename, map_location=loc_type)
model_state_disk = checkpoint['state_dict']
version = checkpoint.get("version", None)
if version is not None:
logger.info('==> Checkpoint trained from version: %s' % version)
logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')
model_state = self.state_dict()
model_state_disk_filter = {}
for key, val in model_state_disk.items():
if key in model_state and model_state_disk[key].shape == model_state[key].shape:
model_state_disk_filter[key] = val
else:
if key not in model_state:
print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')
else:
print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}')
model_state_disk = model_state_disk_filter
missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)
logger.info(f'Missing keys: {missing_keys}')
logger.info(f'The number of missing keys: {len(missing_keys)}')
logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')
logger.info('==> Done (total keys %d)' % (len(model_state)))
epoch = checkpoint.get('epoch', -1)
it = checkpoint.get('it', 0.0)
return it, epoch
def match_token_map(self, data):
traj_pos = data['map_save']['traj_pos'].to(torch.float)
traj_theta = data['map_save']['traj_theta'].to(torch.float)
pl_idx_list = data['map_save']['pl_idx_list']
token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)
token_src = self.map_token['traj_src'].to(traj_pos.device)
max_traj_len = self.map_token['traj_src'].shape[1]
pl_num = traj_pos.shape[0]
pt_token_pos = traj_pos[:, 0, :].clone()
pt_token_orientation = traj_theta.clone()
cos, sin = traj_theta.cos(), traj_theta.sin()
rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
rot_mat[..., 0, 0] = cos
rot_mat[..., 0, 1] = -sin
rot_mat[..., 1, 0] = sin
rot_mat[..., 1, 1] = cos
traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))
pt_token_id = torch.argmin(distance, dim=1)
if self.noise:
topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)
cos, sin = traj_theta.cos(), traj_theta.sin()
rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
rot_mat[..., 0, 0] = cos
rot_mat[..., 0, 1] = sin
rot_mat[..., 1, 0] = -sin
rot_mat[..., 1, 1] = cos
token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
token_src_world_select = token_src_world.view(-1, token_src.shape[0], max_traj_len, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)
pl_idx_full = pl_idx_list.clone()
token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
count_nums = []
for pl in pl_idx_full.unique():
pt = token2pl[0, token2pl[1, :] == pl]
left_side = (data['pt_token']['side'][pt] == 0).sum()
right_side = (data['pt_token']['side'][pt] == 1).sum()
center_side = (data['pt_token']['side'][pt] == 2).sum()
count_nums.append(torch.Tensor([left_side, right_side, center_side]))
count_nums = torch.stack(count_nums, dim=0)
num_polyline = int(count_nums.max().item())
traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=torch.bool)
idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1)
counts_num_expanded = count_nums.unsqueeze(-1)
mask_update = idx_matrix < counts_num_expanded
traj_mask[mask_update] = True
data['pt_token']['traj_mask'] = traj_mask
data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
device=traj_pos.device, dtype=torch.float)], dim=-1)
data['pt_token']['orientation'] = pt_token_orientation
data['pt_token']['height'] = data['pt_token']['position'][:, -1]
data[('pt_token', 'to', 'map_polygon')] = {}
data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
data['pt_token']['token_idx'] = pt_token_id
return data
def sample_pt_pred(self, data):
traj_mask = data['pt_token']['traj_mask']
raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
num_masked = (traj_mask.shape[2] - 1) // 3
perm = torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0] * traj_mask.shape[1] * num_masked]
masked_pt_index = raw_pt_index.view(-1)[perm].reshape(traj_mask.shape[0], traj_mask.shape[1], num_masked)
masked_pt_index = torch.sort(masked_pt_index, -1)[0]
pt_valid_mask = traj_mask.clone()
pt_valid_mask.scatter_(2, masked_pt_index, False)
pt_pred_mask = traj_mask.clone()
pt_pred_mask.scatter_(2, masked_pt_index, False)
tmp_mask = pt_pred_mask.clone()
tmp_mask[:, :, :] = True
tmp_mask.scatter_(2, masked_pt_index-1, False)
pt_pred_mask.masked_fill_(tmp_mask, False)
pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)
data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]
return data
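The optimizer schedule in `configure_optimizers` is linear warmup followed by cosine decay to zero. It can be extracted as a standalone function for inspection (the default step counts here are illustrative, not taken from the training config):

```python
import math

def lr_lambda(current_step, warmup_steps=1000, total_steps=10000):
    """Multiplier applied to the base lr: linear warmup, then cosine decay."""
    if current_step + 1 < warmup_steps:
        return float(current_step + 1) / float(max(1, warmup_steps))
    progress = (current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
```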
================================================
FILE: smart/modules/__init__.py
================================================
from smart.modules.smart_decoder import SMARTDecoder
from smart.modules.map_decoder import SMARTMapDecoder
from smart.modules.agent_decoder import SMARTAgentDecoder
================================================
FILE: smart/modules/agent_decoder.py
================================================
import pickle
from typing import Dict, Mapping, Optional
import torch
import torch.nn as nn
from smart.layers import MLPLayer
from smart.layers.attention_layer import AttentionLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from torch_cluster import radius, radius_graph
from torch_geometric.data import Batch, HeteroData
from torch_geometric.utils import dense_to_sparse, subgraph
from smart.utils import angle_between_2d_vectors, weight_init, wrap_angle
import math
def cal_polygon_contour(x, y, theta, width, length):
left_front_x = x + 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
left_front_y = y + 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
left_front = (left_front_x, left_front_y)
right_front_x = x + 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
right_front_y = y + 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
right_front = (right_front_x, right_front_y)
right_back_x = x - 0.5 * length * math.cos(theta) + 0.5 * width * math.sin(theta)
right_back_y = y - 0.5 * length * math.sin(theta) - 0.5 * width * math.cos(theta)
right_back = (right_back_x, right_back_y)
left_back_x = x - 0.5 * length * math.cos(theta) - 0.5 * width * math.sin(theta)
left_back_y = y - 0.5 * length * math.sin(theta) + 0.5 * width * math.cos(theta)
left_back = (left_back_x, left_back_y)
polygon_contour = [left_front, right_front, right_back, left_back]
return polygon_contour
class SMARTAgentDecoder(nn.Module):
def __init__(self,
dataset: str,
input_dim: int,
hidden_dim: int,
num_historical_steps: int,
time_span: Optional[int],
pl2a_radius: float,
a2a_radius: float,
num_freq_bands: int,
num_layers: int,
num_heads: int,
head_dim: int,
dropout: float,
token_data: Dict,
token_size=512) -> None:
super(SMARTAgentDecoder, self).__init__()
self.dataset = dataset
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.num_historical_steps = num_historical_steps
self.time_span = time_span if time_span is not None else num_historical_steps
self.pl2a_radius = pl2a_radius
self.a2a_radius = a2a_radius
self.num_freq_bands = num_freq_bands
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
self.dropout = dropout
input_dim_x_a = 2
input_dim_r_t = 4
input_dim_r_pt2a = 3
input_dim_r_a2a = 3
input_dim_token = 8
self.type_a_emb = nn.Embedding(4, hidden_dim)
self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)
self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
num_freq_bands=num_freq_bands)
self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
num_freq_bands=num_freq_bands)
self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim)
self.t_attn_layers = nn.ModuleList(
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
)
self.pt2a_attn_layers = nn.ModuleList(
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
)
self.a2a_attn_layers = nn.ModuleList(
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
)
self.token_size = token_size
self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=self.token_size)
self.trajectory_token = token_data['token']
self.trajectory_token_traj = token_data['traj']
self.trajectory_token_all = token_data['token_all']
self.apply(weight_init)
self.shift = 5
self.beam_size = 5
self.hist_mask = True
    def transform_rel(self, token_traj, prev_pos, prev_heading=None):
        # Rotate the token trajectories by prev_heading (estimated from the last
        # displacement when not given) and offset them by the most recent position.
        if prev_heading is None:
diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
num_agent, num_step, traj_num, traj_dim = token_traj.shape
cos, sin = prev_heading.cos(), prev_heading.sin()
rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
rot_mat[:, :, 0, 0] = cos
rot_mat[:, :, 0, 1] = -sin
rot_mat[:, :, 1, 0] = sin
rot_mat[:, :, 1, 1] = cos
agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
return agent_pred_rel
def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False):
num_agent, num_step, traj_dim = pos_a.shape
motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
agent_type = data['agent']['type']
veh_mask = (agent_type == 0)
cyc_mask = (agent_type == 2)
ped_mask = (agent_type == 1)
trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1))
trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
if inference:
agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(
torch.float)
trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(
torch.float)
trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(
torch.float)
agent_token_traj_all[veh_mask] = torch.cat(
[trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
agent_token_traj_all[ped_mask] = torch.cat(
[trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
agent_token_traj_all[cyc_mask] = torch.cat(
[trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]
agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device)
agent_token_traj[veh_mask] = trajectory_token_veh
agent_token_traj[ped_mask] = trajectory_token_ped
agent_token_traj[cyc_mask] = trajectory_token_cyc
vel = data['agent']['token_velocity']
categorical_embs = [
self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step,
dim=0),
self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave(
repeats=num_step,
dim=0)
]
feature_a = torch.stack(
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
], dim=-1)
x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
categorical_embs=categorical_embs)
x_a = x_a.view(-1, num_step, self.hidden_dim)
feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
feat_a = self.fusion_emb(feat_a)
if inference:
return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs
else:
return feat_a, agent_token_traj
def agent_predict_next(self, data, agent_category, feat_a):
num_agent, num_step, traj_dim = data['agent']['token_pos'].shape
agent_type = data['agent']['type']
veh_mask = (agent_type == 0) # * agent_category==3
cyc_mask = (agent_type == 2) # * agent_category==3
ped_mask = (agent_type == 1) # * agent_category==3
token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device)
        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
        # Only a single shared prediction head is defined in __init__, so all
        # agent types use self.token_predict_head (no per-type heads exist).
        token_res[cyc_mask] = self.token_predict_head(feat_a[cyc_mask])
        token_res[ped_mask] = self.token_predict_head(feat_a[ped_mask])
return token_res
def agent_predict_next_inf(self, data, agent_category, feat_a):
num_agent, traj_dim = feat_a.shape
agent_type = data['agent']['type']
veh_mask = (agent_type == 0) # * agent_category==3
cyc_mask = (agent_type == 2) # * agent_category==3
ped_mask = (agent_type == 1) # * agent_category==3
token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device)
        token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
        # As in agent_predict_next, only the shared head exists in __init__.
        token_res[cyc_mask] = self.token_predict_head(feat_a[cyc_mask])
        token_res[ped_mask] = self.token_predict_head(feat_a[ped_mask])
return token_res
def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None):
pos_t = pos_a.reshape(-1, self.input_dim)
head_t = head_a.reshape(-1)
head_vector_t = head_vector_a.reshape(-1, 2)
hist_mask = mask.clone()
        if self.hist_mask and self.training:
            # Randomly invalidate 10 timesteps per agent during training so the
            # temporal attention learns to handle missing history.
            hist_mask[torch.arange(mask.shape[0]).unsqueeze(1),
                      torch.randint(0, mask.shape[1], (num_agent, 10))] = False
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
elif inference_mask is not None:
mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
else:
mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
edge_index_t = dense_to_sparse(mask_t)[0]
edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]]
edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift]
rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
r_t = torch.stack(
[torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
rel_head_t,
edge_index_t[0] - edge_index_t[1]], dim=-1)
r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
return edge_index_t, r_t
def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s):
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
head_s = head_a.transpose(0, 1).reshape(-1)
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
max_num_neighbors=300)
edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]
rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
r_a2a = torch.stack(
[torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
rel_head_a2a], dim=-1)
r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
return edge_index_a2a, r_a2a
def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a, mask,
batch_s, batch_pl):
mask_pl2a = mask.clone()
mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
head_s = head_a.transpose(0, 1).reshape(-1)
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
orient_pl = data['pt_token']['orientation'].contiguous()
pos_pl = pos_pl.repeat(num_step, 1)
orient_pl = orient_pl.repeat(num_step)
edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]
rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
r_pl2a = torch.stack(
[torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
rel_orient_pl2a], dim=-1)
r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
return edge_index_pl2a, r_pl2a
def forward(self,
data: HeteroData,
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
pos_a = data['agent']['token_pos']
head_a = data['agent']['token_heading']
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
num_agent, num_step, traj_dim = pos_a.shape
agent_category = data['agent']['category']
agent_token_index = data['agent']['token_idx']
feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index,
pos_a, head_vector_a)
agent_valid_mask = data['agent']['agent_valid_mask'].clone()
# eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
# agent_valid_mask[~eval_mask] = False
mask = agent_valid_mask
edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask)
if isinstance(data, Batch):
batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
for t in range(num_step)], dim=0)
batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
for t in range(num_step)], dim=0)
else:
batch_s = torch.arange(num_step,
device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
batch_pl = torch.arange(num_step,
device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
mask_s = mask.transpose(0, 1).reshape(-1)
edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s)
mask[agent_category != 3] = False
edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
head_vector_a, mask, batch_s, batch_pl)
for i in range(self.num_layers):
feat_a = feat_a.reshape(-1, self.hidden_dim)
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
feat_a = feat_a.reshape(-1, num_step,
self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape
next_token_prob = self.token_predict_head(feat_a)
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
_, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
next_token_eval_mask = mask.clone()
next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
next_token_eval_mask[:, -1] = False
return {'x_a': feat_a,
'next_token_idx': next_token_idx,
'next_token_prob': next_token_prob,
'next_token_idx_gt': next_token_index_gt,
'next_token_eval_mask': next_token_eval_mask,
}
def inference(self,
data: HeteroData,
map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
pos_a = data['agent']['token_pos'].clone()
head_a = data['agent']['token_heading'].clone()
num_agent, num_step, traj_dim = pos_a.shape
pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
agent_valid_mask = data['agent']['agent_valid_mask'].clone()
agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
agent_valid_mask[~eval_mask] = False
agent_token_index = data['agent']['token_idx']
agent_category = data['agent']['category']
feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(
data,
agent_category,
agent_token_index,
pos_a,
head_vector_a,
inference=True)
agent_type = data["agent"]["type"]
veh_mask = (agent_type == 0) # * agent_category==3
cyc_mask = (agent_type == 2) # * agent_category==3
ped_mask = (agent_type == 1) # * agent_category==3
av_mask = data["agent"]["av_index"]
self.num_recurrent_steps_val = data["agent"]['position'].shape[1]-self.num_historical_steps
pred_traj = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device)
pred_head = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, device=feat_a.device)
pred_prob = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val // self.shift, device=feat_a.device)
next_token_idx_list = []
mask = agent_valid_mask.clone()
feat_a_t_dict = {}
for t in range(self.num_recurrent_steps_val // self.shift):
if t == 0:
inference_mask = mask.clone()
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
else:
inference_mask = torch.zeros_like(mask)
inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask, inference_mask)
if isinstance(data, Batch):
batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
for t in range(num_step)], dim=0)
batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
for t in range(num_step)], dim=0)
else:
batch_s = torch.arange(num_step,
device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
batch_pl = torch.arange(num_step,
device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
            # During inference, we only build map-to-agent edges for the current
            # step of the recurrent rollout.
edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
head_vector_a,
inference_mask, batch_s,
batch_pl)
mask_s = inference_mask.transpose(0, 1).reshape(-1)
edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a,
batch_s, mask_s)
for i in range(self.num_layers):
if i in feat_a_t_dict:
feat_a = feat_a_t_dict[i]
feat_a = feat_a.reshape(-1, self.hidden_dim)
feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
feat_a = feat_a.reshape(-1, num_step,
self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
-1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
if i+1 not in feat_a_t_dict:
feat_a_t_dict[i+1] = feat_a
else:
feat_a_t_dict[i+1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1)
            expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, self.shift + 1, 4, 2)
next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index)
theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
cos, sin = theta.cos(), theta.sin()
rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = sin
rot_mat[:, 1, 0] = -sin
rot_mat[:, 1, 1] = cos
agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
-1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]
sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device)
            agent_pred_rel = agent_pred_rel.gather(
                dim=1,
                index=sample_index[..., None, None, None].expand(-1, -1, self.shift + 1, 4, 2))[:, 0, ...]
pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0]
pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
next_token_idx = next_token_idx.gather(dim=1, index=sample_index)
next_token_idx = next_token_idx.squeeze(-1)
next_token_idx_list.append(next_token_idx[:, None])
agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
next_token_idx[veh_mask]]
agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
next_token_idx[ped_mask]]
agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
next_token_idx[cyc_mask]]
motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
            vel = motion_vector_a.clone() / (0.1 * self.shift)  # each token spans self.shift steps of 0.1 s
vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
x_a = torch.stack(
[torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)
x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
categorical_embs=categorical_embs)
x_a = x_a.view(-1, num_step, self.hidden_dim)
feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
feat_a = self.fusion_emb(feat_a)
agent_valid_mask[agent_category != 3] = False
return {
'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(),
'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
'pred_traj': pred_traj,
'pred_head': pred_head,
'next_token_idx': torch.cat(next_token_idx_list, dim=-1),
'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1),
'next_token_eval_mask': data['agent']['agent_valid_mask'],
'pred_prob': pred_prob,
'vel': vel
}
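As a quick sanity check of the box-corner math used throughout this decoder, the snippet below duplicates `cal_polygon_contour` (so it runs without the repository) and evaluates it for an axis-aligned box; at heading 0 the corners are simply the half-extents around the centre:

```python
import math

def cal_polygon_contour(x, y, theta, width, length):
    # Same corner math as the decoder's helper: rotate the half-extents of the
    # box by the heading theta and offset from the centre (x, y).
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    half_l, half_w = 0.5 * length, 0.5 * width
    return [
        (x + half_l * cos_t - half_w * sin_t, y + half_l * sin_t + half_w * cos_t),  # front-left
        (x + half_l * cos_t + half_w * sin_t, y + half_l * sin_t - half_w * cos_t),  # front-right
        (x - half_l * cos_t + half_w * sin_t, y - half_l * sin_t - half_w * cos_t),  # back-right
        (x - half_l * cos_t - half_w * sin_t, y - half_l * sin_t + half_w * cos_t),  # back-left
    ]

# A 4 m long, 2 m wide vehicle at the origin with heading 0.
contour = cal_polygon_contour(0.0, 0.0, 0.0, width=2.0, length=4.0)
print(contour)  # [(2.0, 1.0), (2.0, -1.0), (-2.0, -1.0), (-2.0, 1.0)]
```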
================================================
FILE: smart/modules/map_decoder.py
================================================
import os.path
from typing import Dict
import torch
import torch.nn as nn
from torch_cluster import radius_graph
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from torch_geometric.utils import dense_to_sparse, subgraph
from smart.utils.nan_checker import check_nan_inf
from smart.layers.attention_layer import AttentionLayer
from smart.layers import MLPLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from smart.utils import angle_between_2d_vectors
from smart.utils import merge_edges
from smart.utils import weight_init
from smart.utils import wrap_angle
import pickle
class SMARTMapDecoder(nn.Module):
def __init__(self,
dataset: str,
input_dim: int,
hidden_dim: int,
num_historical_steps: int,
pl2pl_radius: float,
num_freq_bands: int,
num_layers: int,
num_heads: int,
head_dim: int,
dropout: float,
map_token) -> None:
super(SMARTMapDecoder, self).__init__()
self.dataset = dataset
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.num_historical_steps = num_historical_steps
self.pl2pl_radius = pl2pl_radius
self.num_freq_bands = num_freq_bands
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
self.dropout = dropout
if input_dim == 2:
input_dim_r_pt2pt = 3
elif input_dim == 3:
input_dim_r_pt2pt = 4
else:
raise ValueError('{} is not a valid dimension'.format(input_dim))
self.type_pt_emb = nn.Embedding(17, hidden_dim)
self.side_pt_emb = nn.Embedding(4, hidden_dim)
self.polygon_type_emb = nn.Embedding(4, hidden_dim)
self.light_pl_emb = nn.Embedding(4, hidden_dim)
self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
num_freq_bands=num_freq_bands)
self.pt2pt_layers = nn.ModuleList(
[AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
)
self.token_size = 1024
self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=self.token_size)
input_dim_token = 22
self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
self.map_token = map_token
self.apply(weight_init)
self.mask_pt = False
def maybe_autocast(self, dtype=torch.float32):
return torch.cuda.amp.autocast(dtype=dtype)
def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
pt_valid_mask = data['pt_token']['pt_valid_mask']
pt_pred_mask = data['pt_token']['pt_pred_mask']
pt_target_mask = data['pt_token']['pt_target_mask']
mask_s = pt_valid_mask
pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
orient_pt = data['pt_token']['orientation'].contiguous()
orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]
        if self.input_dim in (2, 3):
            x_pt = pt_token_emb
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))
token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
token_light_type = data['map_polygon']['light_type'][token2pl[1]]
x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
self.polygon_type_emb(data['pt_token']['pl_type'].long()),
self.light_pl_emb(token_light_type.long()),]
x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
loop=False, max_num_neighbors=100)
if self.mask_pt:
edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
if self.input_dim == 2:
r_pt2pt = torch.stack(
[torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
nbr_vector=rel_pos_pt2pt[:, :2]),
rel_orient_pt2pt], dim=-1)
elif self.input_dim == 3:
r_pt2pt = torch.stack(
[torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
nbr_vector=rel_pos_pt2pt[:, :2]),
rel_pos_pt2pt[:, -1],
rel_orient_pt2pt], dim=-1)
else:
raise ValueError('{} is not a valid dimension'.format(self.input_dim))
r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
for i in range(self.num_layers):
x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)
next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
_, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]
return {
'x_pt': x_pt,
'map_next_token_idx': next_token_idx,
'map_next_token_prob': next_token_prob,
'map_next_token_idx_gt': next_token_index_gt,
'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
}
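The relative edge features built here (`r_pt2pt`, and analogously `r_t`, `r_pl2a`, `r_a2a` in the agent decoder) all follow the same recipe: distance, bearing of the source in the target's frame, and wrapped relative orientation. A minimal scalar sketch, with `wrap_angle` and `angle_between_2d_vectors` re-implemented here as assumed equivalents of the `smart.utils` helpers:

```python
import math

def wrap_angle(angle):
    # Wrap an angle to [-pi, pi); assumed to match smart.utils.wrap_angle.
    return -math.pi + (angle + math.pi) % (2 * math.pi)

def angle_between_2d_vectors(ctr, nbr):
    # Signed angle from ctr to nbr via atan2(cross, dot).
    return math.atan2(ctr[0] * nbr[1] - ctr[1] * nbr[0],
                      ctr[0] * nbr[0] + ctr[1] * nbr[1])

# Two map tokens: source token 0 at the origin, target token 1 at (3, 3)
# rotated 90 degrees. One directed edge 0 -> 1, as in the r_pt2pt features.
pos = [(0.0, 0.0), (3.0, 3.0)]
orient = [0.0, math.pi / 2]
src, dst = 0, 1

rel_pos = (pos[src][0] - pos[dst][0], pos[src][1] - pos[dst][1])
dst_vec = (math.cos(orient[dst]), math.sin(orient[dst]))
r_pt2pt = (
    math.hypot(*rel_pos),                        # distance between tokens
    angle_between_2d_vectors(dst_vec, rel_pos),  # bearing of source in target frame
    wrap_angle(orient[src] - orient[dst]),       # relative orientation
)
print(r_pt2pt)  # ~ (4.2426, 2.3562, -1.5708)
```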
================================================
FILE: smart/modules/smart_decoder.py
================================================
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch_geometric.data import HeteroData
from smart.modules.agent_decoder import SMARTAgentDecoder
from smart.modules.map_decoder import SMARTMapDecoder
class SMARTDecoder(nn.Module):
def __init__(self,
dataset: str,
input_dim: int,
hidden_dim: int,
num_historical_steps: int,
pl2pl_radius: float,
time_span: Optional[int],
pl2a_radius: float,
a2a_radius: float,
num_freq_bands: int,
num_map_layers: int,
num_agent_layers: int,
num_heads: int,
head_dim: int,
dropout: float,
map_token: Dict,
token_data: Dict,
use_intention=False,
token_size=512) -> None:
super(SMARTDecoder, self).__init__()
self.map_encoder = SMARTMapDecoder(
dataset=dataset,
input_dim=input_dim,
hidden_dim=hidden_dim,
num_historical_steps=num_historical_steps,
pl2pl_radius=pl2pl_radius,
num_freq_bands=num_freq_bands,
num_layers=num_map_layers,
num_heads=num_heads,
head_dim=head_dim,
dropout=dropout,
map_token=map_token
)
self.agent_encoder = SMARTAgentDecoder(
dataset=dataset,
input_dim=input_dim,
hidden_dim=hidden_dim,
num_historical_steps=num_historical_steps,
time_span=time_span,
pl2a_radius=pl2a_radius,
a2a_radius=a2a_radius,
num_freq_bands=num_freq_bands,
num_layers=num_agent_layers,
num_heads=num_heads,
head_dim=head_dim,
dropout=dropout,
token_size=token_size,
token_data=token_data
)
self.map_enc = None
def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
map_enc = self.map_encoder(data)
agent_enc = self.agent_encoder(data, map_enc)
return {**map_enc, **agent_enc}
def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
map_enc = self.map_encoder(data)
agent_enc = self.agent_encoder.inference(data, map_enc)
return {**map_enc, **agent_enc}
def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
agent_enc = self.agent_encoder.inference(data, map_enc)
return {**map_enc, **agent_enc}
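`SMARTDecoder` merges the two sub-encoders' outputs with `{**map_enc, **agent_enc}`; with Python dict unpacking, any key present in both dicts takes its value from the later (agent) dict. A toy illustration with hypothetical keys:

```python
# Hypothetical encoder outputs; 'shared' stands in for any overlapping key.
map_enc = {'x_pt': 'map features', 'shared': 'from map'}
agent_enc = {'x_a': 'agent features', 'shared': 'from agent'}

merged = {**map_enc, **agent_enc}
print(merged['shared'])  # from agent
```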
================================================
FILE: smart/preprocess/__init__.py
================================================
================================================
FILE: smart/preprocess/preprocess.py
================================================
import numpy as np
import pandas as pd
import os
import torch
from typing import Any, Dict, List, Optional
predict_unseen_agents = False
vector_repr = True
_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
_point_sides = ['LEFT', 'RIGHT', 'CENTER']
_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
_polygon_is_intersections = [True, False, None]
Lane_type_hash = {
4: "BIKE",
3: "VEHICLE",
2: "VEHICLE",
1: "BUS"
}
boundary_type_hash = {
5: "UNKNOWN",
6: "DASHED_WHITE",
7: "SOLID_WHITE",
8: "DOUBLE_DASH_WHITE",
9: "DASHED_YELLOW",
10: "DOUBLE_DASH_YELLOW",
11: "SOLID_YELLOW",
12: "DOUBLE_SOLID_YELLOW",
13: "DASH_SOLID_YELLOW",
14: "UNKNOWN",
15: "EDGE",
16: "EDGE"
}
def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
if not predict_unseen_agents: # filter out agents that are unseen during the historical time steps
historical_df = df[df['timestep'] == num_historical_steps-1]
agent_ids = list(historical_df['track_id'].unique())
df = df[df['track_id'].isin(agent_ids)]
else:
agent_ids = list(df['track_id'].unique())
num_agents = len(agent_ids)
# initialization
valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
agent_id: List[Optional[str]] = [None] * num_agents
agent_type = torch.zeros(num_agents, dtype=torch.uint8)
agent_category = torch.zeros(num_agents, dtype=torch.uint8)
position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
for track_id, track_df in df.groupby('track_id'):
agent_idx = agent_ids.index(track_id)
agent_steps = track_df['timestep'].values
valid_mask[agent_idx, agent_steps] = True
current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
predict_mask[agent_idx, agent_steps] = True
if vector_repr: # a time step t is valid only when both t and t-1 are valid
valid_mask[agent_idx, 1: num_historical_steps] = (
valid_mask[agent_idx, :num_historical_steps - 1] &
valid_mask[agent_idx, 1: num_historical_steps])
valid_mask[agent_idx, 0] = False
predict_mask[agent_idx, :num_historical_steps] = False
if not current_valid_mask[agent_idx]:
predict_mask[agent_idx, num_historical_steps:] = False
agent_id[agent_idx] = track_id
agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
agent_category[agent_idx] = track_df['object_category'].values[0]
position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
track_df['position_y'].values,
track_df['position_z'].values],
axis=-1)).float()
heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
track_df['velocity_y'].values],
axis=-1)).float()
shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
track_df['width'].values,
track_df["height"].values],
axis=-1)).float()
av_idx = agent_id.index(av_id)
return {
'num_nodes': num_agents,
'av_index': av_idx,
'valid_mask': valid_mask,
'predict_mask': predict_mask,
'id': agent_id,
'type': agent_type,
'category': agent_category,
'position': position,
'heading': heading,
'velocity': velocity,
'shape': shape
}
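The vector_repr branch above is easy to misread: a history step t stays valid only when both t and t-1 were observed, so isolated observations and the first step after a gap both drop out. A pure-Python sketch of that rule (plain lists stand in for the torch masks; the toy track is a hypothetical example, not repo data):

```python
# Illustrative stand-in for the vector_repr masking above: a history step t
# remains valid only when steps t and t-1 were both observed.

def history_valid(valid, num_historical_steps):
    out = list(valid)
    for t in range(1, num_historical_steps):
        # read from the original mask, mirroring the vectorized torch op
        out[t] = valid[t] and valid[t - 1]
    out[0] = False  # step 0 has no predecessor, so it can never form a vector
    return out

# an agent observed at steps 0, 1, 3, 4 of a 5-step history
print(history_valid([True, True, False, True, True], num_historical_steps=5))
# -> [False, True, False, False, True]: step 2 was unobserved, and step 3
#    drops out because its predecessor is invalid
```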
================================================
FILE: smart/tokens/__init__.py
================================================
================================================
FILE: smart/transforms/__init__.py
================================================
from smart.transforms.target_builder import WaymoTargetBuilder
================================================
FILE: smart/transforms/target_builder.py
================================================
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import BaseTransform
from smart.utils import wrap_angle
from smart.utils.log import Logging
def to_16(data):
if isinstance(data, dict):
for key, value in data.items():
new_value = to_16(value)
data[key] = new_value
if isinstance(data, torch.Tensor):
if data.dtype == torch.float32:
data = data.to(torch.float16)
return data
def tofloat32(data):
for name in data:
value = data[name]
if isinstance(value, dict):
value = tofloat32(value)
elif isinstance(value, torch.Tensor) and value.dtype == torch.float64:
value = value.to(torch.float32)
data[name] = value
return data
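`to_16` and `tofloat32` are both recursive dtype walkers over (possibly nested) dicts of tensors. Stripped of torch, their shared traversal pattern is just "map a function over every matching leaf, in place". A stdlib sketch of that pattern (the converter and sample dict are hypothetical, for illustration only):

```python
# Traversal pattern shared by to_16 / tofloat32 above: recurse through nested
# dicts, convert matching leaves, and write results back in place
# (floats stand in for float32 tensors).

def convert_leaves(data, should_convert, convert):
    if isinstance(data, dict):
        for key, value in data.items():
            data[key] = convert_leaves(value, should_convert, convert)
        return data
    return convert(data) if should_convert(data) else data

batch = {"pos": 1.25, "meta": {"heading": 0.5, "id": "av"}}
convert_leaves(batch, lambda v: isinstance(v, float), lambda v: v * 2)
print(batch)  # -> {'pos': 2.5, 'meta': {'heading': 1.0, 'id': 'av'}}
```

Non-matching leaves (the string `"id"` here) pass through untouched, just as non-float tensors do in the originals.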
class WaymoTargetBuilder(BaseTransform):
def __init__(self,
num_historical_steps: int,
num_future_steps: int,
mode="train") -> None:
self.num_historical_steps = num_historical_steps
self.num_future_steps = num_future_steps
self.mode = mode
self.num_features = 3
self.augment = False
self.logger = Logging().log(level='DEBUG')
def score_ego_agent(self, agent):
av_index = agent['av_index']
agent["category"][av_index] = 5
return agent
def clip(self, agent, max_num=32):
av_index = agent["av_index"]
valid = agent['valid_mask']
ego_pos = agent["position"][av_index]
obstacle_mask = agent['type'] == 3
        distance = torch.norm(agent["position"][:, self.num_historical_steps - 1, :2] -
                              ego_pos[self.num_historical_steps - 1, :2], dim=-1)  # keep the max_num agents closest to the ego vehicle
        distance[obstacle_mask] = 1e6  # push static obstacles to the end of the sort
sort_idx = distance.sort()[1]
mask = torch.zeros(valid.shape[0])
mask[sort_idx[:max_num]] = 1
mask = mask.to(torch.bool)
mask[av_index] = True
new_av_index = mask[:av_index].sum()
agent["num_nodes"] = int(mask.sum())
agent["av_index"] = int(new_av_index)
excluded = ["num_nodes", "av_index", "ego"]
for key, val in agent.items():
if key in excluded:
continue
if key == "id":
val = list(np.array(val)[mask])
agent[key] = val
continue
if len(val.size()) > 1:
agent[key] = val[mask, ...]
else:
agent[key] = val[mask]
return agent
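`clip` keeps the `max_num` agents nearest the ego, always retains the ego itself, and recomputes the ego's index in the filtered ordering. The same selection logic in a stdlib sketch (toy 2D coordinates are a hypothetical example):

```python
import math

# Stand-in for the selection in clip() above: keep the max_num agents nearest
# the ego position, force-keep the ego, and recompute its index after masking.

def nearest_agent_mask(positions, av_index, max_num):
    ego = positions[av_index]
    dist = [math.hypot(x - ego[0], y - ego[1]) for x, y in positions]
    order = sorted(range(len(positions)), key=dist.__getitem__)
    keep = set(order[:max_num]) | {av_index}   # ego is always kept
    mask = [i in keep for i in range(len(positions))]
    new_av_index = sum(mask[:av_index])        # ego's position in the kept subset
    return mask, new_av_index

positions = [(0.0, 0.0), (50.0, 0.0), (5.0, 0.0), (100.0, 0.0)]
print(nearest_agent_mask(positions, av_index=2, max_num=2))
# -> ([True, False, True, False], 1): the ego and its nearest neighbour
#    survive, and the ego slides from index 2 to index 1
```

This mirrors why `clip` computes `new_av_index = mask[:av_index].sum()`: counting kept agents before the ego gives its index in the compacted arrays.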
def score_nearby_vehicle(self, agent, max_num=10):
av_index = agent['av_index']
agent["category"] = torch.zeros_like(agent["category"])
obstacle_mask = agent['type'] == 3
pos = agent["position"][av_index, self.num_historical_steps, :2]
distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
        distance[obstacle_mask] = 1e6  # push static obstacles to the end of the sort
sort_idx = distance.sort()[1]
nearby_mask = torch.zeros(distance.shape[0])
nearby_mask[sort_idx[1:max_num]] = 1
nearby_mask = nearby_mask.bool()
agent["category"][nearby_mask] = 3
agent["category"][obstacle_mask] = 0
def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
av_index = agent['av_index']
agent["category"] = torch.zeros_like(agent["category"])
pos = agent["position"][av_index, self.num_historical_steps, :2]
distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
distance_all_time = torch.norm(agent["position"][:, :, :2] - agent["position"][av_index, :, :2], dim=-1)
        in_range_mask = distance_all_time < 150  # perception beyond 150 m of the ego is treated as unreliable
        agent["valid_mask"] = agent["valid_mask"] * in_range_mask
        # do not predict vehicles too far away from the ego car
        close_vehicle = distance < 100
        valid = agent['valid_mask']
        valid_current = valid[:, self.num_historical_steps:]
        valid_counts = valid_current.sum(1)
        counts_vehicle = valid_counts >= 1
        no_background = agent['type'] != 3
        vehicle2pred = close_vehicle & counts_vehicle & no_background
if vehicle2pred.sum() > max_num:
            # too many flagged vehicles: randomly subsample down to max_num for training
true_indices = torch.nonzero(vehicle2pred).squeeze(1)
selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]
vehicle2pred.fill_(False)
vehicle2pred[selected_indices] = True
agent["category"][vehicle2pred] = 3
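The subsampling step above caps the number of training targets per scene: when more candidates are flagged than `max_num`, a uniformly random subset is kept. A stdlib sketch of that cap (the original uses `torch.randperm`; the seeded RNG here is an assumption for reproducibility of the example):

```python
import random

# Stand-in for the cap in score_trained_vehicle above: when more candidates
# are flagged than max_num, keep a uniformly random subset of exactly max_num.

def cap_candidates(flags, max_num, seed=0):
    true_idx = [i for i, f in enumerate(flags) if f]
    if len(true_idx) <= max_num:
        return list(flags)          # nothing to do, keep all candidates
    chosen = set(random.Random(seed).sample(true_idx, max_num))
    return [i in chosen for i in range(len(flags))]

capped = cap_candidates([True] * 6 + [False] * 2, max_num=3)
print(sum(capped))  # exactly 3 survive, all drawn from the original True set
```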
def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):
origin = position[:, num_historical_steps - 1]
theta = heading[:, num_historical_steps - 1]
cos, sin = theta.cos(), theta.sin()
rot_mat = theta.new_zeros(num_nodes, 2, 2)
rot_mat[:, 0, 0] = cos
rot_mat[:, 0, 1] = -sin
rot_mat[:, 1, 0] = sin
rot_mat[:, 1, 1] = cos
target = origin.new_zeros(num_nodes, num_future_steps, 4)
target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -
origin[:, :2].unsqueeze(1), rot_mat)
his = origin.new_zeros(num_nodes, num_historical_steps, 4)
his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -
origin[:, :2].unsqueeze(1), rot_mat)
if position.size(2) == 3:
target[..., 2] = (position[:, num_historical_steps:, 2] -
origin[:, 2].unsqueeze(-1))
his[..., 2] = (position[:, :num_historical_steps, 2] -
origin[:, 2].unsqueeze(-1))
target[..., 3] = wrap_angle(heading[:, num_historical_steps:] -
theta.unsqueeze(-1))
his[..., 3] = wrap_angle(heading[:, :num_historical_steps] -
theta.unsqueeze(-1))
else:
target[..., 2] = wrap_angle(heading[:, num_historical_steps:] -
theta.unsqueeze(-1))
his[..., 2] = wrap_angle(heading[:, :num_historical_steps] -
theta.unsqueeze(-1))
return his, target
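Note the convention in `rotate_agents`: displacements are row vectors right-multiplied by `rot_mat = [[cos, -sin], [sin, cos]]`, which is equivalent to rotating by -theta, i.e. mapping world displacements into each agent's local frame. A quick stdlib check of that convention (the single-agent numbers are a hypothetical example):

```python
import math

# World -> agent-frame transform matching torch.bmm(delta, rot_mat) above:
# a row vector times [[cos, -sin], [sin, cos]] rotates it by -theta, so a
# point ahead of the agent in world coordinates lands on its local +x axis.

def to_agent_frame(dx, dy, theta):
    cos, sin = math.cos(theta), math.sin(theta)
    return (dx * cos + dy * sin, -dx * sin + dy * cos)

# agent heading is +90 degrees; a point 1 m ahead of it sits at (0, 1) in
# world displacement, which should become (1, 0) in the agent's frame
x, y = to_agent_frame(0.0, 1.0, math.pi / 2)
print(round(x, 6), round(y, 6))  # -> 1.0 0.0
```

Had the code left-multiplied (`rot_mat @ delta`), the same matrix would instead rotate by +theta; the row-vector convention is what makes `torch.bmm(position - origin, rot_mat)` produce agent-centric targets.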
def __call__(self, data) -> HeteroData:
agent = data["agent"]
self.score_ego_agent(agent)
self.score_trained_vehicle(agent, max_num=32)
return HeteroData(data)
================================================
FILE: smart/utils/__init__.py
================================================
from smart.utils.geometry import angle_between_2d_vectors
from smart.utils.geometry import angle_between_3d_vectors
from smart.utils.geometry import side_to_directed_lineseg
from smart.utils.geometry import wrap_angle
from smart.utils.graph import add_edges
from smart.utils.graph import bipartite_dense_to_sparse
from smart.utils.graph import complete_graph
from smart.utils.graph import merge_edges
from smart.utils.graph import unbatch
from smart.utils.list import safe_list_index
from smart.utils.weight_init import weight_init
================================================
FILE: smart/utils/cluster_reader.py
================================================
import io
import pickle
import pandas as pd
import json
class LoadScenarioFromCeph:
def __init__(self):
from petrel_client.client import Client
self.file_client = Client('~/petreloss.conf')
def list(self, dir_path):
return list(self.file_client.list(dir_path))
def save(self, data, url):
self.file_client.put(url, pickle.dumps(data))
def read_correct_csv(self, scenario_path):
output = pd.read_csv(io.StringIO(self.file_client.get(scenario_path).decode('utf-8')), engine="python")
return output
def contains(self, url):
return self.file_client.contains(url)
    def read_string(self, csv_url):
        df = pd.read_csv(io.StringIO(self.file_client.get(csv_url).decode('utf-8')), sep=r'\s+', low_memory=False)
        return df
def read(self, scenario_path):
with io.BytesIO(self.file_client.get(scenario_path)) as f:
datas = pickle.load(f)
return datas
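`LoadScenarioFromCeph` is tied to the proprietary `petrel_client`. For local experiments, a filesystem stand-in exposing the same `list`/`save`/`contains`/`read` surface is often enough; the class below is an illustrative assumption, not part of the repo, and only works if callers restrict themselves to these four methods:

```python
import os
import pickle

class LoadScenarioFromDisk:
    """Local-filesystem stand-in for LoadScenarioFromCeph (illustrative only)."""

    def list(self, dir_path):
        return sorted(os.listdir(dir_path))

    def save(self, data, url):
        with open(url, "wb") as f:
            pickle.dump(data, f)

    def contains(self, url):
        return os.path.exists(url)

    def read(self, scenario_path):
        with open(scenario_path, "rb") as f:
            return pickle.load(f)
```

Swapping this in would let scenario loading run against a local directory without a `~/petreloss.conf`; only the constructor differs from the Ceph version.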
SYMBOL INDEX (167 symbols across 27 files)
FILE: data_preprocess.py
function safe_list_index (line 65) | def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
function get_agent_features (line 72) | def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10,...
function get_map_features (line 145) | def get_map_features(map_infos, tf_current_light, dim=3):
function process_agent (line 335) | def process_agent(track_info, tracks_to_predict, sdc_track_index, scenar...
function process_dynamic_map (line 382) | def process_dynamic_map(dynamic_map_infos):
function decode_tracks_from_proto (line 466) | def decode_tracks_from_proto(tracks):
function decode_map_features_from_proto (line 488) | def decode_map_features_from_proto(map_features):
function decode_dynamic_map_states_from_proto (line 610) | def decode_dynamic_map_states_from_proto(dynamic_map_states):
function process_single_data (line 630) | def process_single_data(scenario):
function wm2argo (line 663) | def wm2argo(file, dir_name, output_dir):
function batch_process9s_transformer (line 693) | def batch_process9s_transformer(dir_name, output_dir, num_workers=2):
FILE: scripts/traj_clstering.py
function average_distance_vectorized (line 5) | def average_distance_vectorized(point_set1, centroids):
function assign_clusters (line 10) | def assign_clusters(sub_X, centroids):
function Kdisk_cluster (line 15) | def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
function cal_polygon_contour (line 58) | def cal_polygon_contour(x, y, theta, width, length):
FILE: smart/datamodules/scalable_datamodule.py
class MultiDataModule (line 9) | class MultiDataModule(pl.LightningDataModule):
method __init__ (line 18) | def __init__(self,
method setup (line 69) | def setup(self, stage: Optional[str] = None) -> None:
method train_dataloader (line 77) | def train_dataloader(self):
method val_dataloader (line 82) | def val_dataloader(self):
method test_dataloader (line 87) | def test_dataloader(self):
FILE: smart/datasets/preprocess.py
function cal_polygon_contour (line 10) | def cal_polygon_contour(x, y, theta, width, length):
function interplating_polyline (line 33) | def interplating_polyline(polylines, heading, distance=0.5, split_distac...
function average_distance_vectorized (line 117) | def average_distance_vectorized(point_set1, centroids):
function assign_clusters (line 122) | def assign_clusters(sub_X, centroids):
class TokenProcessor (line 127) | class TokenProcessor:
method __init__ (line 129) | def __init__(self, token_size):
method preprocess (line 140) | def preprocess(self, data):
method get_trajectory_token (line 150) | def get_trajectory_token(self):
method clean_heading (line 171) | def clean_heading(self, data):
method tokenize_agent (line 195) | def tokenize_agent(self, data):
method match_token (line 295) | def match_token(self, pos, valid_mask, heading, category, agent_catego...
method tokenize_map (line 403) | def tokenize_map(self, data):
FILE: smart/datasets/scalable_dataset.py
function distance (line 11) | def distance(point1, point2):
class MultiDataset (line 15) | class MultiDataset(Dataset):
method __init__ (line 16) | def __init__(self,
method raw_dir (line 66) | def raw_dir(self) -> str:
method raw_paths (line 70) | def raw_paths(self) -> List[str]:
method raw_file_names (line 74) | def raw_file_names(self) -> Union[str, List[str], Tuple]:
method processed_file_names (line 78) | def processed_file_names(self) -> Union[str, List[str], Tuple]:
method len (line 81) | def len(self) -> int:
method generate_ref_token (line 84) | def generate_ref_token(self):
method get (line 87) | def get(self, idx: int):
FILE: smart/layers/attention_layer.py
class AttentionLayer (line 12) | class AttentionLayer(MessagePassing):
method __init__ (line 14) | def __init__(self,
method forward (line 57) | def forward(self,
method message (line 74) | def message(self,
method update (line 90) | def update(self,
method _attn_block (line 97) | def _attn_block(self,
method _ff_block (line 108) | def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
FILE: smart/layers/fourier_embedding.py
class FourierEmbedding (line 9) | class FourierEmbedding(nn.Module):
method __init__ (line 11) | def __init__(self,
method forward (line 35) | def forward(self,
class MLPEmbedding (line 56) | class MLPEmbedding(nn.Module):
method __init__ (line 57) | def __init__(self,
method forward (line 73) | def forward(self,
FILE: smart/layers/mlp_layer.py
class MLPLayer (line 8) | class MLPLayer(nn.Module):
method __init__ (line 10) | def __init__(self,
method forward (line 23) | def forward(self, x: torch.Tensor) -> torch.Tensor:
FILE: smart/metrics/average_meter.py
class AverageMeter (line 6) | class AverageMeter(Metric):
method __init__ (line 8) | def __init__(self, **kwargs) -> None:
method update (line 13) | def update(self, val: torch.Tensor) -> None:
method compute (line 17) | def compute(self) -> torch.Tensor:
FILE: smart/metrics/min_ade.py
class minMultiADE (line 11) | class minMultiADE(Metric):
method __init__ (line 13) | def __init__(self,
method update (line 21) | def update(self,
method compute (line 44) | def compute(self) -> torch.Tensor:
class minADE (line 48) | class minADE(Metric):
method __init__ (line 50) | def __init__(self,
method update (line 59) | def update(self,
method compute (line 84) | def compute(self) -> torch.Tensor:
FILE: smart/metrics/min_fde.py
class minMultiFDE (line 10) | class minMultiFDE(Metric):
method __init__ (line 12) | def __init__(self,
method update (line 20) | def update(self,
method compute (line 34) | def compute(self) -> torch.Tensor:
class minFDE (line 38) | class minFDE(Metric):
method __init__ (line 40) | def __init__(self,
method update (line 49) | def update(self,
method compute (line 60) | def compute(self) -> torch.Tensor:
FILE: smart/metrics/next_token_cls.py
class TokenCls (line 10) | class TokenCls(Metric):
method __init__ (line 12) | def __init__(self,
method update (line 20) | def update(self,
method compute (line 29) | def compute(self) -> torch.Tensor:
FILE: smart/metrics/utils.py
function topk (line 8) | def topk(
function topkind (line 44) | def topkind(
function valid_filter (line 80) | def valid_filter(
function new_batch_nms (line 108) | def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
function batch_nms (line 163) | def batch_nms(pred_trajs, pred_scores,
function batch_nms_token (line 224) | def batch_nms_token(pred_trajs, pred_scores,
FILE: smart/model/smart.py
function cal_polygon_contour (line 20) | def cal_polygon_contour(x, y, theta, width, length):
function joint_scene_from_states (line 41) | def joint_scene_from_states(states, object_ids) -> sim_agents_submission...
class SMART (line 53) | class SMART(pl.LightningModule):
method __init__ (line 55) | def __init__(self, model_config) -> None:
method get_trajectory_token (line 106) | def get_trajectory_token(self):
method init_map_token (line 113) | def init_map_token(self):
method forward (line 124) | def forward(self, data: HeteroData):
method inference (line 128) | def inference(self, data: HeteroData):
method maybe_autocast (line 132) | def maybe_autocast(self, dtype=torch.float16):
method training_step (line 140) | def training_step(self,
method validation_step (line 157) | def validation_step(self,
method on_validation_start (line 199) | def on_validation_start(self):
method configure_optimizers (line 205) | def configure_optimizers(self):
method load_params_from_file (line 218) | def load_params_from_file(self, filename, logger, to_cpu=False):
method match_token_map (line 257) | def match_token_map(self, data):
method sample_pt_pred (line 321) | def sample_pt_pred(self, data):
FILE: smart/modules/agent_decoder.py
function cal_polygon_contour (line 15) | def cal_polygon_contour(x, y, theta, width, length):
class SMARTAgentDecoder (line 36) | class SMARTAgentDecoder(nn.Module):
method __init__ (line 38) | def __init__(self,
method transform_rel (line 110) | def transform_rel(self, token_traj, prev_pos, prev_heading=None):
method agent_token_embedding (line 126) | def agent_token_embedding(self, data, agent_category, agent_token_inde...
method agent_predict_next (line 194) | def agent_predict_next(self, data, agent_category, feat_a):
method agent_predict_next_inf (line 206) | def agent_predict_next_inf(self, data, agent_category, feat_a):
method build_temporal_edge (line 221) | def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent,...
method build_interaction_edge (line 249) | def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s...
method build_map2agent_edge (line 265) | def build_map2agent_edge(self, data, num_step, agent_category, pos_a, ...
method forward (line 288) | def forward(self,
method inference (line 351) | def inference(self,
FILE: smart/modules/map_decoder.py
class SMARTMapDecoder (line 20) | class SMARTMapDecoder(nn.Module):
method __init__ (line 22) | def __init__(self,
method maybe_autocast (line 73) | def maybe_autocast(self, dtype=torch.float32):
method forward (line 76) | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
FILE: smart/modules/smart_decoder.py
class SMARTDecoder (line 9) | class SMARTDecoder(nn.Module):
method __init__ (line 11) | def __init__(self,
method forward (line 62) | def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
method inference (line 67) | def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
method inference_no_map (line 72) | def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, tor...
FILE: smart/preprocess/preprocess.py
function get_agent_features (line 44) | def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10,...
FILE: smart/transforms/target_builder.py
function to_16 (line 10) | def to_16(data):
function tofloat32 (line 21) | def tofloat32(data):
class WaymoTargetBuilder (line 32) | class WaymoTargetBuilder(BaseTransform):
method __init__ (line 34) | def __init__(self,
method score_ego_agent (line 45) | def score_ego_agent(self, agent):
method clip (line 50) | def clip(self, agent, max_num=32):
method score_nearby_vehicle (line 79) | def score_nearby_vehicle(self, agent, max_num=10):
method score_trained_vehicle (line 93) | def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
method rotate_agents (line 117) | def rotate_agents(self, position, heading, num_nodes, num_historical_s...
method __call__ (line 148) | def __call__(self, data) -> HeteroData:
FILE: smart/utils/cluster_reader.py
class LoadScenarioFromCeph (line 7) | class LoadScenarioFromCeph:
method __init__ (line 8) | def __init__(self):
method list (line 12) | def list(self, dir_path):
method save (line 15) | def save(self, data, url):
method read_correct_csv (line 18) | def read_correct_csv(self, scenario_path):
method contains (line 22) | def contains(self, url):
method read_string (line 25) | def read_string(self, csv_url):
method read (line 30) | def read(self, scenario_path):
method read_json (line 35) | def read_json(self, path):
method read_csv (line 40) | def read_csv(self, scenario_path):
method read_model (line 43) | def read_model(self, model_path):
FILE: smart/utils/config.py
function load_config_act (line 6) | def load_config_act(path):
function load_config_init (line 13) | def load_config_init(path):
FILE: smart/utils/geometry.py
function angle_between_2d_vectors (line 7) | def angle_between_2d_vectors(
function angle_between_3d_vectors (line 14) | def angle_between_3d_vectors(
function side_to_directed_lineseg (line 21) | def side_to_directed_lineseg(
function wrap_angle (line 35) | def wrap_angle(
FILE: smart/utils/graph.py
function add_edges (line 9) | def add_edges(
function merge_edges (line 33) | def merge_edges(
function complete_graph (line 45) | def complete_graph(
function bipartite_dense_to_sparse (line 76) | def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
function unbatch (line 85) | def unbatch(
FILE: smart/utils/list.py
function safe_list_index (line 5) | def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
FILE: smart/utils/log.py
class Logging (line 6) | class Logging:
method make_log_dir (line 8) | def make_log_dir(self, dirname='logs'):
method get_log_filename (line 16) | def get_log_filename(self):
method log (line 22) | def log(self, level='DEBUG', name="simagent"):
method add_log (line 36) | def add_log(self, logger, level='DEBUG'):
FILE: smart/utils/nan_checker.py
function check_nan_inf (line 3) | def check_nan_inf(t, s):
FILE: smart/utils/weight_init.py
function weight_init (line 5) | def weight_init(m: nn.Module) -> None: