Repository: rainmaker22/SMART
Branch: main
Commit: 42e658542b03
Files: 52
Total size: 205.4 KB
Directory structure:
gitextract_98xwsw9y/
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── configs/
│ ├── train/
│ │ └── train_scalable.yaml
│ └── validation/
│ └── validation_scalable.yaml
├── data_preprocess.py
├── environment.yml
├── pyproject.toml
├── requirements.txt
├── scripts/
│ ├── install_pyg.sh
│ └── traj_clstering.py
├── smart/
│ ├── __init__.py
│ ├── datamodules/
│ │ ├── __init__.py
│ │ └── scalable_datamodule.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── preprocess.py
│ │ └── scalable_dataset.py
│ ├── layers/
│ │ ├── __init__.py
│ │ ├── attention_layer.py
│ │ ├── fourier_embedding.py
│ │ └── mlp_layer.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── average_meter.py
│ │ ├── min_ade.py
│ │ ├── min_fde.py
│ │ ├── next_token_cls.py
│ │ └── utils.py
│ ├── model/
│ │ ├── __init__.py
│ │ └── smart.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── agent_decoder.py
│ │ ├── map_decoder.py
│ │ └── smart_decoder.py
│ ├── preprocess/
│ │ ├── __init__.py
│ │ └── preprocess.py
│ ├── tokens/
│ │ ├── __init__.py
│ │ ├── cluster_frame_5_2048.pkl
│ │ └── map_traj_token5.pkl
│ ├── transforms/
│ │ ├── __init__.py
│ │ └── target_builder.py
│ └── utils/
│ ├── __init__.py
│ ├── cluster_reader.py
│ ├── config.py
│ ├── geometry.py
│ ├── graph.py
│ ├── list.py
│ ├── log.py
│ ├── nan_checker.py
│ └── weight_init.py
├── train.py
└── val.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.github
ckpt/
# assets/
# C extensions
*.so
# /assets
/data
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
*.jpg
env/
venv/
ENV/
env.bak/
venv.bak/
*.jpg
pyg_depend/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# IDEs
.idea
.vscode
# seed project
av2/
lightning_logs/
lightning_logs_/
lightning_l/
.DS_Store
data/argo
data/res
data/waymo*
fig*/
data/waymo_token
data/submission
data/token_seq_emb_nuplan
data/token_seq_emb_waymo
data/nuplan*
submission.tar.gz
data/feat*
data/scalable
data/pos_data
res_metrics*
gathered*
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction
[Paper](https://arxiv.org/abs/2405.15677) | [Webpage](https://smart-motion.github.io/smart/)
- **Ranked 1st** on the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/)
- **Champion** of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
## News
- **[December 31, 2024]** SMART-Planner achieved state-of-the-art performance on **nuPlan closed-loop planning**
- **[September 26, 2024]** SMART was **accepted to** NeurIPS 2024
- **[August 31, 2024]** Code released
- **[May 24, 2024]** SMART won the championship of the [Waymo Open Sim Agents Challenge 2024](https://waymo.com/open/challenges/2024/sim-agents/) at the [CVPR 2024 Workshop on Autonomous Driving (WAD)](https://cvpr2024.wad.vision/)
- **[May 24, 2024]** SMART paper released on [arxiv](https://arxiv.org/abs/2405.15677)
## Introduction
This repository contains the official implementation of SMART: Scalable Multi-agent Real-time Motion Generation via Next-token Prediction. SMART is a novel autonomous driving motion generation paradigm that models vectorized map and agent trajectory data into discrete sequence tokens.
https://github.com/user-attachments/assets/74a61627-8444-4e54-bb10-d317dd2aacd9
## Requirements
To set up the environment, you can use conda to create and activate a new environment with the necessary dependencies:
```bash
conda env create -f environment.yml
conda activate SMART
pip install -r requirements.txt
```
If you encounter issues while installing pyg dependencies, execute the following script:
```setup
bash scripts/install_pyg.sh
```
Alternatively, you can configure the environment in your preferred way. Installing the latest versions of PyTorch, PyG, and PyTorch Lightning should suffice.
## Data installation
**Step 1: Download the Dataset**
Download the Waymo Open Motion Dataset (`scenario protocol` format) and organize the data as follows:
```
SMART
├── data
│ ├── waymo
│ │ ├── scenario
│ │ │ ├──training
│ │ │ ├──validation
│ │ │ ├──testing
├── model
├── tools
```
**Step 2: Install the Waymo Open Dataset API**
Follow the instructions [here](https://github.com/waymo-research/waymo-open-dataset) to install the Waymo Open Dataset API.
**Step 3: Preprocess the Dataset**
Preprocess the dataset by running:
```
python data_preprocess.py --input_dir ./data/waymo/scenario/training --output_dir ./data/waymo_processed/training
```
The first path is the raw data path, and the second is the output data path.
The processed data will be saved to the `data/waymo_processed/` directory as follows:
```
SMART
├── data
│ ├── waymo_processed
│ │ ├── training
│ │ ├── validation
│ │ ├──testing
├── model
├── utils
```
## Training
To train the model, run the following command:
```train
python train.py --config ${config_path}
```
The default config path is `configs/train/train_scalable.yaml`. Ensure you have downloaded and prepared the Waymo data for training.
## Evaluation
To evaluate the model, run:
```eval
python val.py --config ${config_path} --pretrain_ckpt ${ckpt_path}
```
This will evaluate the model using the configuration and checkpoint provided.
## Pre-trained Models
To comply with the WOMD participation agreement, we will release the model parameters of a medium-sized model not trained on Waymo data. Users can fine-tune this model with Waymo data as needed.
## Results
### Waymo Open Motion Dataset Sim Agents Challenge
Our model achieves the following performance on the [Waymo Open Motion Dataset Sim Agents Challenge](https://waymo.com/open/challenges/2024/sim-agents/):
| Model name | Metric Score |
| :-----------: | ------------ |
| SMART-tiny | 0.7591 |
| SMART-large | 0.7614 |
| SMART-zeroshot| 0.7210 |
### NuPlan Closed-loop Planning
**SMART-Planner** achieved state-of-the-art performance among learning-based algorithms on **nuPlan closed-loop planning**. The results on val14 are shown below:

## Citation
If you find this repository useful, please consider citing our work and giving us a star:
```citation
@article{wu2024smart,
title={SMART: Scalable Multi-agent Real-time Simulation via Next-token Prediction},
author={Wu, Wei and Feng, Xiaoxin and Gao, Ziyan and Kan, Yuheng},
journal={arXiv preprint arXiv:2405.15677},
year={2024}
}
```
## Acknowledgements
Special thanks to the [QCNET](https://github.com/ZikangZhou/QCNet) repository for providing valuable reference code that significantly influenced this work.
## License
All code in this repository is licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
================================================
FILE: __init__.py
================================================
================================================
FILE: configs/train/train_scalable.yaml
================================================
# Config format schema number; this YAML supports validating cases sourced from different datasets
time_info: &time_info
num_historical_steps: 11
num_future_steps: 80
use_intention: True
token_size: 2048
Dataset:
root:
train_batch_size: 1
val_batch_size: 1
test_batch_size: 1
shuffle: True
num_workers: 1
pin_memory: True
persistent_workers: True
train_raw_dir: ["data/valid_demo"]
val_raw_dir: ["data/valid_demo"]
test_raw_dir:
transform: WaymoTargetBuilder
train_processed_dir:
val_processed_dir:
test_processed_dir:
dataset: "scalable"
<<: *time_info
Trainer:
strategy: ddp_find_unused_parameters_false
accelerator: "gpu"
devices: 1
max_epochs: 32
save_ckpt_path:
num_nodes: 1
mode:
ckpt_path:
precision: 32
accumulate_grad_batches: 1
Model:
mode: "train"
predictor: "smart"
dataset: "waymo"
input_dim: 2
hidden_dim: 128
output_dim: 2
output_head: False
num_heads: 8
<<: *time_info
head_dim: 16
dropout: 0.1
num_freq_bands: 64
lr: 0.0005
warmup_steps: 0
total_steps: 32
decoder:
<<: *time_info
num_map_layers: 3
num_agent_layers: 6
a2a_radius: 60
pl2pl_radius: 10
pl2a_radius: 30
time_span: 30
================================================
FILE: configs/validation/validation_scalable.yaml
================================================
# Config format schema number; this YAML supports validating cases sourced from different datasets
time_info: &time_info
num_historical_steps: 11
num_future_steps: 80
token_size: 2048
Dataset:
root:
batch_size: 1
shuffle: True
num_workers: 1
pin_memory: True
persistent_workers: True
train_raw_dir:
val_raw_dir: ["data/valid_demo"]
test_raw_dir:
TargetBuilder: WaymoTargetBuilder
train_processed_dir:
val_processed_dir:
test_processed_dir:
dataset: "scalable"
<<: *time_info
Trainer:
strategy: ddp_find_unused_parameters_false
accelerator: "gpu"
devices: 1
max_epochs: 32
save_ckpt_path:
num_nodes: 1
mode:
ckpt_path:
precision: 32
accumulate_grad_batches: 1
Model:
mode: "validation"
predictor: "smart"
dataset: "waymo"
input_dim: 2
hidden_dim: 128
output_dim: 2
output_head: False
num_heads: 8
<<: *time_info
head_dim: 16
dropout: 0.1
num_freq_bands: 64
lr: 0.0005
warmup_steps: 0
total_steps: 32
decoder:
<<: *time_info
num_map_layers: 3
num_agent_layers: 6
a2a_radius: 60
pl2pl_radius: 10
pl2a_radius: 30
time_span: 30
================================================
FILE: data_preprocess.py
================================================
import numpy as np
import pandas as pd
import os
import torch
import pickle
from tqdm import tqdm
from typing import Any, Dict, List, Optional
import easydict
# ---------------------------------------------------------------------------
# Preprocessing flags and label vocabularies.
# ---------------------------------------------------------------------------
predict_unseen_agents = False  # when False, drop agents not observed at the last history step
vector_repr = True  # when True, a timestep t is valid only if both t and t-1 are valid
root = ''
split = 'train'
raw_dir = os.path.join(root, split, 'raw')
_raw_dir = raw_dir
if os.path.isdir(_raw_dir):
    # os.listdir already returns a list; no identity comprehension needed.
    _raw_file_names = os.listdir(_raw_dir)
else:
    _raw_file_names = []
processed_dir = os.path.join(root, split, 'processed')
_processed_dir = processed_dir
if os.path.isdir(_processed_dir):
    _processed_file_names = [name for name in os.listdir(_processed_dir) if
                             name.endswith(('pkl', 'pickle'))]
else:
    _processed_file_names = []
# Index positions in these lists are the integer labels stored in the tensors.
_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
_point_sides = ['LEFT', 'RIGHT', 'CENTER']
_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
_polygon_is_intersections = [True, False, None]
# Maps Waymo integer lane types to the _polygon_types vocabulary above.
Lane_type_hash = {
    4: "BIKE",
    3: "VEHICLE",
    2: "VEHICLE",
    1: "BUS"
}
# Maps Waymo integer boundary-line types to the _point_types vocabulary above.
boundary_type_hash = {
    5: "UNKNOWN",
    6: "DASHED_WHITE",
    7: "SOLID_WHITE",
    8: "DOUBLE_DASH_WHITE",
    9: "DASHED_YELLOW",
    10: "DOUBLE_DASH_YELLOW",
    11: "SOLID_YELLOW",
    12: "DOUBLE_SOLID_YELLOW",
    13: "DASH_SOLID_YELLOW",
    14: "UNKNOWN",
    15: "EDGE",
    16: "EDGE"
}
def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    """Return the index of ``elem`` in ``ls``, or ``None`` when it is absent."""
    if elem in ls:
        return ls.index(elem)
    return None
def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
    """Pack per-agent trajectory rows from ``df`` into dense, padded tensors.

    Args:
        df: long-format dataframe with one row per (agent, timestep); reads
            'track_id', 'timestep', 'object_type', 'object_category',
            position/heading/velocity/length/width/height columns.
        av_id: track id of the autonomous vehicle; its row index is returned.
        num_historical_steps: number of observed (history) timesteps.
        dim: feature dimensionality of position/velocity/shape tensors.
        num_steps: total timesteps per agent (history + future).

    Returns:
        Dict with [num_agents, num_steps, ...] tensors, validity/prediction
        masks, agent ids/types/categories, and the AV index.
    """
    if not predict_unseen_agents: # filter out agents that are unseen during the historical time steps
        historical_df = df[df['timestep'] == num_historical_steps-1]
        agent_ids = list(historical_df['track_id'].unique())
        df = df[df['track_id'].isin(agent_ids)]
    else:
        agent_ids = list(df['track_id'].unique())
    num_agents = len(agent_ids)
    # initialization: zero/False padding for steps where an agent is absent
    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    agent_id: List[Optional[str]] = [None] * num_agents
    agent_type = torch.zeros(num_agents, dtype=torch.uint8)
    agent_category = torch.zeros(num_agents, dtype=torch.uint8)
    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    for track_id, track_df in df.groupby('track_id'):
        agent_idx = agent_ids.index(track_id)
        agent_steps = track_df['timestep'].values
        valid_mask[agent_idx, agent_steps] = True
        # "currently valid" = observed at the final history step
        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
        predict_mask[agent_idx, agent_steps] = True
        if vector_repr: # a time step t is valid only when both t and t-1 are valid
            valid_mask[agent_idx, 1: num_historical_steps] = (
                valid_mask[agent_idx, :num_historical_steps - 1] &
                valid_mask[agent_idx, 1: num_historical_steps])
            valid_mask[agent_idx, 0] = False
        # never predict history steps, and never predict agents that are
        # absent at the current (last history) step
        predict_mask[agent_idx, :num_historical_steps] = False
        if not current_valid_mask[agent_idx]:
            predict_mask[agent_idx, num_historical_steps:] = False
        agent_id[agent_idx] = track_id
        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
        agent_category[agent_idx] = track_df['object_category'].values[0]
        # NOTE(review): the first 3 channels are always written here, which
        # assumes dim >= 3 for position/shape — confirm for dim=2 callers.
        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
                                                                          track_df['position_y'].values,
                                                                          track_df['position_z'].values],
                                                                         axis=-1)).float()
        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
                                                                          track_df['velocity_y'].values],
                                                                         axis=-1)).float()
        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
                                                                       track_df['width'].values,
                                                                       track_df["height"].values],
                                                                      axis=-1)).float()
    av_idx = agent_id.index(av_id)
    if split == 'test':
        # in the test split, force prediction for currently-valid agents and
        # for categories 2/3 (tracks to predict)
        predict_mask[current_valid_mask
                     | (agent_category == 2)
                     | (agent_category == 3), num_historical_steps:] = True
    return {
        'num_nodes': num_agents,
        'av_index': av_idx,
        'valid_mask': valid_mask,
        'predict_mask': predict_mask,
        'id': agent_id,
        'type': agent_type,
        'category': agent_category,
        'position': position,
        'heading': heading,
        'velocity': velocity,
        'shape': shape
    }
def get_map_features(map_infos, tf_current_light, dim=3):
    """Build the heterogeneous map graph from preprocessed polyline info.

    Args:
        map_infos: dict with 'lane', 'road_edge', 'road_line', 'crosswalk'
            element lists (each element carries 'id' and 'polyline_index')
            and the shared 'all_polylines' array they index into.
        tf_current_light: dataframe of current traffic-light states with
            'lane_id' and 'state' columns.
        dim: number of spatial dimensions kept for point positions.

    Returns:
        Nested dict keyed by node/edge type ('map_polygon', 'map_point',
        and the two relation tuples) holding feature tensors and edge
        indices.
    """
    lane_segments = map_infos['lane']
    all_polylines = map_infos["all_polylines"]
    crosswalks = map_infos['crosswalk']
    road_edges = map_infos['road_edge']
    road_lines = map_infos['road_line']
    lane_segment_ids = [info["id"] for info in lane_segments]
    cross_walk_ids = [info["id"] for info in crosswalks]
    road_edge_ids = [info["id"] for info in road_edges]
    road_line_ids = [info["id"] for info in road_lines]
    # Ordering matters: a polygon's integer index is its position in this list.
    polygon_ids = lane_segment_ids + road_edge_ids + road_line_ids + cross_walk_ids
    num_polygons = len(polygon_ids)
    # Per-polygon features; light type defaults to index 3 ('LANE_STATE_UNKNOWN').
    polygon_type = torch.zeros(num_polygons, dtype=torch.uint8)
    polygon_light_type = torch.ones(num_polygons, dtype=torch.uint8) * 3
    # Per-polygon lists of per-point feature tensors, concatenated at the end.
    point_position: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_orientation: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_magnitude: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_height: List[Optional[torch.Tensor]] = [None] * num_polygons
    point_type: List[Optional[torch.Tensor]] = [None] * num_polygons

    def _fill_point_features(polygon_idx, polyline_index, point_type_name):
        """Slice one polyline and populate the per-point features for it.

        Each of the N-1 segment vectors between consecutive polyline points
        yields one map point (position of the segment start, orientation,
        planar magnitude, height delta, and a uniform point-type label).
        """
        centerline = torch.from_numpy(
            all_polylines[polyline_index[0]:polyline_index[1], :]).float()
        vectors = centerline[1:] - centerline[:-1]
        point_position[polygon_idx] = centerline[:-1, :dim]
        point_orientation[polygon_idx] = torch.atan2(vectors[:, 1], vectors[:, 0])
        point_magnitude[polygon_idx] = torch.norm(vectors[:, :2], p=2, dim=-1)
        point_height[polygon_idx] = vectors[:, 2]
        point_type[polygon_idx] = torch.full(
            (len(vectors),), _point_types.index(point_type_name), dtype=torch.uint8)

    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        polygon_type[lane_segment_idx] = _polygon_types.index(Lane_type_hash[lane_segment.type])
        # Attach the current traffic-light state when this lane has one.
        res = tf_current_light[tf_current_light["lane_id"] == str(lane_segment.id)]
        if len(res) != 0:
            polygon_light_type[lane_segment_idx] = _polygon_light_type.index(res["state"].item())
        _fill_point_features(lane_segment_idx, lane_segment.polyline_index, 'CENTERLINE')
    for road_edge in road_edges:
        road_edge = easydict.EasyDict(road_edge)
        road_edge_idx = polygon_ids.index(road_edge.id)
        polygon_type[road_edge_idx] = _polygon_types.index("VEHICLE")
        _fill_point_features(road_edge_idx, road_edge.polyline_index, 'EDGE')
    for road_line in road_lines:
        road_line = easydict.EasyDict(road_line)
        road_line_idx = polygon_ids.index(road_line.id)
        polygon_type[road_line_idx] = _polygon_types.index("VEHICLE")
        _fill_point_features(road_line_idx, road_line.polyline_index,
                             boundary_type_hash[road_line.type])
    for crosswalk in crosswalks:
        crosswalk = easydict.EasyDict(crosswalk)
        crosswalk_idx = polygon_ids.index(crosswalk.id)
        polygon_type[crosswalk_idx] = _polygon_types.index("PEDESTRIAN")
        _fill_point_features(crosswalk_idx, crosswalk.polyline_index, "CROSSWALK")

    num_points = torch.tensor([point.size(0) for point in point_position], dtype=torch.long)
    # Edge from each point to the polygon it belongs to.
    point_to_polygon_edge_index = torch.stack(
        [torch.arange(num_points.sum(), dtype=torch.long),
         torch.arange(num_polygons, dtype=torch.long).repeat_interleave(num_points)], dim=0)
    polygon_to_polygon_edge_index = []
    polygon_to_polygon_type = []

    def _add_neighbor_edges(neighbor_ids, lane_segment_idx, relation_name):
        """Append (neighbor -> lane) edges for every resolvable neighbor id."""
        neighbor_inds = [idx for idx in (safe_list_index(polygon_ids, nid)
                                         for nid in neighbor_ids)
                         if idx is not None]
        if len(neighbor_inds) != 0:
            polygon_to_polygon_edge_index.append(
                torch.stack([torch.tensor(neighbor_inds, dtype=torch.long),
                             torch.full((len(neighbor_inds),), lane_segment_idx,
                                        dtype=torch.long)], dim=0))
            polygon_to_polygon_type.append(
                torch.full((len(neighbor_inds),),
                           _polygon_to_polygon_types.index(relation_name),
                           dtype=torch.uint8))

    for lane_segment in lane_segments:
        lane_segment = easydict.EasyDict(lane_segment)
        lane_segment_idx = polygon_ids.index(lane_segment.id)
        _add_neighbor_edges(lane_segment.entry_lanes, lane_segment_idx, 'PRED')
        _add_neighbor_edges(lane_segment.exit_lanes, lane_segment_idx, 'SUCC')
        # LEFT/RIGHT neighbors are appended one edge at a time, matching the
        # original emission order.
        for left_neighbor_id in lane_segment.left_neighbors:
            _add_neighbor_edges([left_neighbor_id], lane_segment_idx, 'LEFT')
        for right_neighbor_id in lane_segment.right_neighbors:
            _add_neighbor_edges([right_neighbor_id], lane_segment_idx, 'RIGHT')
    if len(polygon_to_polygon_edge_index) != 0:
        polygon_to_polygon_edge_index = torch.cat(polygon_to_polygon_edge_index, dim=1)
        polygon_to_polygon_type = torch.cat(polygon_to_polygon_type, dim=0)
    else:
        polygon_to_polygon_edge_index = torch.tensor([[], []], dtype=torch.long)
        polygon_to_polygon_type = torch.tensor([], dtype=torch.uint8)

    map_data = {
        'map_polygon': {},
        'map_point': {},
        ('map_point', 'to', 'map_polygon'): {},
        ('map_polygon', 'to', 'map_polygon'): {},
    }
    map_data['map_polygon']['num_nodes'] = num_polygons
    map_data['map_polygon']['type'] = polygon_type
    map_data['map_polygon']['light_type'] = polygon_light_type
    if len(num_points) == 0:
        # No polygons at all: emit empty tensors so downstream code can rely
        # on every key being present.
        map_data['map_point']['num_nodes'] = 0
        map_data['map_point']['position'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['orientation'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['magnitude'] = torch.tensor([], dtype=torch.float)
        if dim == 3:
            map_data['map_point']['height'] = torch.tensor([], dtype=torch.float)
        map_data['map_point']['type'] = torch.tensor([], dtype=torch.uint8)
        # NOTE(review): 'side' is only set in this empty branch (the non-empty
        # branch never populates it) — confirm whether consumers expect it.
        map_data['map_point']['side'] = torch.tensor([], dtype=torch.uint8)
    else:
        map_data['map_point']['num_nodes'] = num_points.sum().item()
        map_data['map_point']['position'] = torch.cat(point_position, dim=0)
        map_data['map_point']['orientation'] = torch.cat(point_orientation, dim=0)
        map_data['map_point']['magnitude'] = torch.cat(point_magnitude, dim=0)
        if dim == 3:
            map_data['map_point']['height'] = torch.cat(point_height, dim=0)
        map_data['map_point']['type'] = torch.cat(point_type, dim=0)
    map_data['map_point', 'to', 'map_polygon']['edge_index'] = point_to_polygon_edge_index
    map_data['map_polygon', 'to', 'map_polygon']['edge_index'] = polygon_to_polygon_edge_index
    map_data['map_polygon', 'to', 'map_polygon']['type'] = polygon_to_polygon_type
    return map_data
def process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, start_timestamp, end_timestamp):
    """Convert decoded Waymo tracks into an Argoverse-style agent DataFrame.

    Args:
        track_info: dict with 'trajs' of shape (num_objects, num_timestamps, 10)
            and parallel 'object_id' / 'object_type' lists.
        tracks_to_predict: dict whose 'track_index' lists the focal agents.
        sdc_track_index: index of the self-driving car track (unused here,
            kept for call-site compatibility).
        scenario_id: id string written into every row.
        start_timestamp: scenario start, stored per row.
        end_timestamp: scenario end, stored per row.

    Returns:
        pd.DataFrame with one row per valid (object, timestep) observation.
    """
    # (num_timestamps, num_objects, 10); the last channel is the validity flag.
    agents_array = track_info["trajs"].transpose(1, 0, 2)
    object_id = np.array(track_info["object_id"])
    object_type = track_info["object_type"]
    id_hash = {object_id[o_idx]: object_type[o_idx] for o_idx in range(len(object_id))}

    def type_hash(x):
        # Map a Waymo object id to its simplified category name.
        tp = id_hash[x]
        type_re_hash = {
            "TYPE_VEHICLE": "vehicle",
            "TYPE_PEDESTRIAN": "pedestrian",
            "TYPE_CYCLIST": "cyclist",
            "TYPE_OTHER": "background",
            "TYPE_UNSET": "background"
        }
        return type_re_hash[tp]

    columns = ['observed', 'track_id', 'object_type', 'object_category', 'timestep',
               'position_x', 'position_y', 'position_z', 'length', 'width', 'height', 'heading', 'velocity_x', 'velocity_y',
               'scenario_id', 'start_timestamp', 'end_timestamp', 'num_timestamps',
               'focal_track_id', 'city']
    # Metadata channels placed in front of the raw state channels, then
    # reordered into the `columns` layout below.
    new_columns = np.ones((agents_array.shape[0], agents_array.shape[1], 11))
    new_columns[:11, :, 0] = True   # first 11 timesteps are the observed history
    new_columns[11:, :, 0] = False
    for index in range(new_columns.shape[0]):
        new_columns[index, :, 4] = int(index)  # timestep
    new_columns[..., 1] = object_id  # track_id
    new_columns[..., 2] = object_id  # object_type (resolved to a name below)
    new_columns[:, tracks_to_predict["track_index"], 3] = 3  # focal agents
    new_columns[..., 5] = 11  # placeholder; overwritten by scenario_id below
    new_columns[..., 6] = int(start_timestamp)
    new_columns[..., 7] = int(end_timestamp)
    new_columns[..., 8] = int(91)  # num_timestamps
    new_columns[..., 9] = object_id  # focal_track_id
    new_columns[..., 10] = 10086  # dummy city code
    new_agents_array = np.concatenate([new_columns, agents_array], axis=-1)
    # Keep only valid observations (validity flag is the last channel), then
    # flatten (timestep, object) pairs into rows.
    new_agents_array = new_agents_array[new_agents_array[..., -1] == 1.0].reshape(-1, new_agents_array.shape[-1])
    # Reorder to match `columns`; this drops the trailing validity channel.
    new_agents_array = new_agents_array[..., [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16, 17, 18, 19, 5, 6, 7, 8, 9, 10]]
    new_agents_array = pd.DataFrame(data=new_agents_array, columns=columns)
    new_agents_array["object_type"] = new_agents_array["object_type"].apply(func=type_hash)
    new_agents_array["start_timestamp"] = new_agents_array["start_timestamp"].astype(int)
    new_agents_array["end_timestamp"] = new_agents_array["end_timestamp"].astype(int)
    new_agents_array["num_timestamps"] = new_agents_array["num_timestamps"].astype(int)
    new_agents_array["scenario_id"] = scenario_id
    return new_agents_array
def process_dynamic_map(dynamic_map_infos):
    """Flatten per-timestep traffic-light observations into one DataFrame.

    Each row is (lane_id, time_step, state), with the fine-grained Waymo
    signal states collapsed to LANE_STATE_STOP / _GO / _CAUTION.
    """
    states = dynamic_map_infos["state"]
    frames = []
    for step, lane_id in enumerate(dynamic_map_infos["lane_id"]):
        step_col = np.ones_like(lane_id) * step
        frames.append(np.concatenate([lane_id, step_col, states[step]], axis=0))
    stacked = np.concatenate(frames, axis=1).transpose(1, 0)
    tf_lights = pd.DataFrame(data=stacked, columns=["lane_id", "time_step", "state"])
    for col in ("time_step", "lane_id", "state"):
        tf_lights[col] = tf_lights[col].astype("str")
    # Collapse arrow/flashing variants onto the three canonical states.
    for needle, canonical in (("STOP", 'LANE_STATE_STOP'),
                              ("GO", 'LANE_STATE_GO'),
                              ("CAUTION", 'LANE_STATE_CAUTION')):
        tf_lights.loc[tf_lights["state"].str.contains(needle), ["state"]] = canonical
    return tf_lights
# Integer codes for every static map-feature / boundary type, shared by all
# decoded polylines (column 3 of 'all_polylines').
polyline_type = {
    # lane centerline types
    'TYPE_UNDEFINED': -1,
    'TYPE_FREEWAY': 1,
    'TYPE_SURFACE_STREET': 2,
    'TYPE_BIKE_LANE': 3,
    # road line types
    'TYPE_UNKNOWN': -1,
    'TYPE_BROKEN_SINGLE_WHITE': 6,
    'TYPE_SOLID_SINGLE_WHITE': 7,
    'TYPE_SOLID_DOUBLE_WHITE': 8,
    'TYPE_BROKEN_SINGLE_YELLOW': 9,
    'TYPE_BROKEN_DOUBLE_YELLOW': 10,
    'TYPE_SOLID_SINGLE_YELLOW': 11,
    'TYPE_SOLID_DOUBLE_YELLOW': 12,
    'TYPE_PASSING_DOUBLE_YELLOW': 13,
    # road edge types
    'TYPE_ROAD_EDGE_BOUNDARY': 15,
    'TYPE_ROAD_EDGE_MEDIAN': 16,
    # stop sign
    'TYPE_STOP_SIGN': 17,
    # crosswalk
    'TYPE_CROSSWALK': 18,
    # speed bump
    'TYPE_SPEED_BUMP': 19
}

# Waymo proto enum -> object type name.
object_type = {
    0: 'TYPE_UNSET',
    1: 'TYPE_VEHICLE',
    2: 'TYPE_PEDESTRIAN',
    3: 'TYPE_CYCLIST',
    4: 'TYPE_OTHER'
}

# Waymo proto enum -> traffic signal state name.
signal_state = {
    0: 'LANE_STATE_UNKNOWN',
    # states for traffic signals with arrows
    1: 'LANE_STATE_ARROW_STOP',
    2: 'LANE_STATE_ARROW_CAUTION',
    3: 'LANE_STATE_ARROW_GO',
    # standard round traffic signals
    4: 'LANE_STATE_STOP',
    5: 'LANE_STATE_CAUTION',
    6: 'LANE_STATE_GO',
    # flashing light signals
    7: 'LANE_STATE_FLASHING_STOP',
    8: 'LANE_STATE_FLASHING_CAUTION'
}

# Reverse lookup: signal state name -> numeric code.
signal_state_to_id = {name: code for code, name in signal_state.items()}
def decode_tracks_from_proto(tracks):
    """Decode agent tracks from a Waymo Scenario proto into numpy arrays."""
    track_infos = {
        'object_id': [],  # proto track id per object
        'object_type': [],  # name from the module-level object_type map, e.g. 'TYPE_VEHICLE'
        'trajs': []
    }
    for cur_data in tracks:  # number of objects
        # Per-step state: [center_x, center_y, center_z, length, width, height,
        # heading, velocity_x, velocity_y, valid]
        cur_traj = [np.array([x.center_x, x.center_y, x.center_z, x.length, x.width, x.height, x.heading,
                              x.velocity_x, x.velocity_y, x.valid], dtype=np.float32) for x in cur_data.states]
        cur_traj = np.stack(cur_traj, axis=0)  # (num_timestamp, 10)
        track_infos['object_id'].append(cur_data.id)
        track_infos['object_type'].append(object_type[cur_data.object_type])
        track_infos['trajs'].append(cur_traj)
    track_infos['trajs'] = np.stack(track_infos['trajs'], axis=0)  # (num_objects, num_timestamp, 10)
    return track_infos
from collections import defaultdict
def decode_map_features_from_proto(map_features):
    """Decode static map features from a Waymo Scenario proto.

    Returns a dict with per-category feature-info lists ('lane', 'road_line',
    'road_edge', 'stop_sign', 'crosswalk', 'speed_bump'), an 'all_polylines'
    array of shape (num_points, 5) holding [x, y, z, global_type, feature_id],
    and lane lookup dicts ('lane_dict', 'lane2other_dict').

    Improvement over the original: the five no-op
    ``np.concatenate((p[:, 0:3], p[:, 3:]), axis=-1)`` reassemblies (which
    rebuilt the exact same array) are removed.
    """
    map_infos = {
        'lane': [],
        'road_line': [],
        'road_edge': [],
        'stop_sign': [],
        'crosswalk': [],
        'speed_bump': [],
        'lane_dict': {},
        'lane2other_dict': {}
    }
    polylines = []
    point_cnt = 0  # running offset into the concatenated polyline array
    lane2other_dict = defaultdict(list)
    for cur_data in map_features:
        cur_info = {'id': cur_data.id}
        if cur_data.lane.ByteSize() > 0:
            cur_info['speed_limit_mph'] = cur_data.lane.speed_limit_mph
            cur_info['type'] = cur_data.lane.type + 1  # 0: undefined, 1: freeway, 2: surface_street, 3: bike_lane
            cur_info['left_neighbors'] = [lane.feature_id for lane in cur_data.lane.left_neighbors]
            cur_info['right_neighbors'] = [lane.feature_id for lane in cur_data.lane.right_neighbors]
            cur_info['interpolating'] = cur_data.lane.interpolating
            cur_info['entry_lanes'] = list(cur_data.lane.entry_lanes)
            cur_info['exit_lanes'] = list(cur_data.lane.exit_lanes)
            # Boundary types are shifted by 5 into the shared polyline_type id space.
            cur_info['left_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary_type'] = [x.boundary_type + 5 for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary'] = [x.boundary_feature_id for x in cur_data.lane.left_boundaries]
            cur_info['right_boundary'] = [x.boundary_feature_id for x in cur_data.lane.right_boundaries]
            cur_info['left_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.left_boundaries]
            cur_info['left_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.left_boundaries]
            cur_info['right_boundary_start_index'] = [lane.lane_start_index for lane in cur_data.lane.right_boundaries]
            cur_info['right_boundary_end_index'] = [lane.lane_end_index for lane in cur_data.lane.right_boundaries]
            lane2other_dict[cur_data.id].extend(cur_info['left_boundary'])
            lane2other_dict[cur_data.id].extend(cur_info['right_boundary'])
            global_type = cur_info['type']
            cur_polyline = np.stack(
                [np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in cur_data.lane.polyline],
                axis=0)
            if cur_polyline.shape[0] <= 1:  # skip degenerate single-point polylines
                continue
            map_infos['lane'].append(cur_info)
            map_infos['lane_dict'][cur_data.id] = cur_info
        elif cur_data.road_line.ByteSize() > 0:
            cur_info['type'] = cur_data.road_line.type + 5
            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_line.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_line'].append(cur_info)
        elif cur_data.road_edge.ByteSize() > 0:
            cur_info['type'] = cur_data.road_edge.type + 14
            global_type = cur_info['type']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.road_edge.polyline], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['road_edge'].append(cur_info)
        elif cur_data.stop_sign.ByteSize() > 0:
            cur_info['lane_ids'] = list(cur_data.stop_sign.lane)
            for i in cur_info['lane_ids']:
                lane2other_dict[i].append(cur_data.id)
            point = cur_data.stop_sign.position
            cur_info['position'] = np.array([point.x, point.y, point.z])
            global_type = polyline_type['TYPE_STOP_SIGN']
            cur_polyline = np.array([point.x, point.y, point.z, global_type, cur_data.id]).reshape(1, 5)
            # NOTE(review): a stop sign is a single point, so this check always
            # fires and stop signs never reach map_infos/polylines. Preserved
            # as-is because downstream code may depend on it — confirm intent.
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['stop_sign'].append(cur_info)
        elif cur_data.crosswalk.ByteSize() > 0:
            global_type = polyline_type['TYPE_CROSSWALK']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.crosswalk.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['crosswalk'].append(cur_info)
        elif cur_data.speed_bump.ByteSize() > 0:
            global_type = polyline_type['TYPE_SPEED_BUMP']
            cur_polyline = np.stack([np.array([point.x, point.y, point.z, global_type, cur_data.id]) for point in
                                     cur_data.speed_bump.polygon], axis=0)
            if cur_polyline.shape[0] <= 1:
                continue
            map_infos['speed_bump'].append(cur_info)
        else:
            # Unknown / empty feature; ignore.
            continue
        polylines.append(cur_polyline)
        cur_info['polyline_index'] = (point_cnt, point_cnt + len(cur_polyline))
        point_cnt += len(cur_polyline)
    # Raises on a scenario with no usable polylines, as in the original.
    polylines = np.concatenate(polylines, axis=0).astype(np.float32)
    map_infos['all_polylines'] = polylines
    map_infos['lane2other_dict'] = lane2other_dict
    return map_infos
def decode_dynamic_map_states_from_proto(dynamic_map_states):
    """Decode per-timestep traffic-signal observations from the proto.

    Returns lists (one entry per timestep) of lane ids, named signal states
    and stop points, each wrapped as a (1, num_observed_signals, ...) array.
    """
    dynamic_map_infos = {
        'lane_id': [],
        'state': [],
        'stop_point': []
    }
    for frame in dynamic_map_states:  # one frame per timestep
        signals = frame.lane_states  # num_observed_signals entries
        lane_ids = [sig.lane for sig in signals]
        states = [signal_state[sig.state] for sig in signals]
        stop_points = [[sig.stop_point.x, sig.stop_point.y, sig.stop_point.z] for sig in signals]
        dynamic_map_infos['lane_id'].append(np.array([lane_ids]))
        dynamic_map_infos['state'].append(np.array([states]))
        dynamic_map_infos['stop_point'].append(np.array([stop_points]))
    return dynamic_map_infos
def process_single_data(scenario):
    """Decode one Waymo Scenario proto into a plain dict of infos."""
    track_infos = decode_tracks_from_proto(scenario.tracks)
    # For training: suggested objects to train on; for val/test: required predictions.
    focal = {
        'track_index': [pred.track_index for pred in scenario.tracks_to_predict],
        'difficulty': [pred.difficulty for pred in scenario.tracks_to_predict],
    }
    focal['object_type'] = [track_infos['object_type'][idx] for idx in focal['track_index']]
    # Decode map-related data.
    map_infos = decode_map_features_from_proto(scenario.map_features)
    dynamic_map_infos = decode_dynamic_map_states_from_proto(scenario.dynamic_map_states)
    save_infos = {
        'track_infos': track_infos,
        'dynamic_map_infos': dynamic_map_infos,
        'map_infos': map_infos,
        'scenario_id': scenario.scenario_id,
        'timestamps_seconds': list(scenario.timestamps_seconds),  # 91 entries
        'current_time_index': scenario.current_time_index,  # int, 10
        'sdc_track_index': scenario.sdc_track_index,  # int
        'objects_of_interest': list(scenario.objects_of_interest),  # may be empty
        'tracks_to_predict': focal,
    }
    return save_infos
import tensorflow as tf
from waymo_open_dataset.protos import scenario_pb2
def wm2argo(file, dir_name, output_dir):
    """Convert every scenario in one Waymo TFRecord into per-scenario pickles.

    Args:
        file: TFRecord file name inside ``dir_name``.
        dir_name: directory holding the raw Waymo scenario shards.
        output_dir: directory that receives one ``<scenario_id>.pkl`` per scenario.
    """
    file_path = os.path.join(dir_name, file)
    dataset = tf.data.TFRecordDataset(file_path, compression_type='', num_parallel_reads=3)
    for cnt, data in enumerate(dataset):
        scenario = scenario_pb2.Scenario()
        scenario.ParseFromString(bytearray(data.numpy()))
        save_infos = process_single_data(scenario)  # pkl2mtr
        map_info = save_infos["map_infos"]
        track_info = save_infos['track_infos']
        scenario_id = save_infos['scenario_id']
        tracks_to_predict = save_infos['tracks_to_predict']
        sdc_track_index = save_infos['sdc_track_index']
        av_id = track_info["object_id"][sdc_track_index]
        if len(tracks_to_predict["track_index"]) < 1:
            # BUG FIX: the original `return` aborted the whole TFRecord when a
            # single scenario had no prediction targets; skip just that scenario.
            continue
        dynamic_map_infos = save_infos["dynamic_map_infos"]
        tf_lights = process_dynamic_map(dynamic_map_infos)
        # Traffic lights at the current step (index 11 = first future step).
        tf_current_light = tf_lights.loc[tf_lights["time_step"] == "11"]
        map_data = get_map_features(map_info, tf_current_light)
        new_agents_array = process_agent(track_info, tracks_to_predict, sdc_track_index, scenario_id, 0, 91)  # mtr2argo
        data = dict()
        data['scenario_id'] = new_agents_array['scenario_id'].values[0]
        data['city'] = new_agents_array['city'].values[0]
        data['agent'] = get_agent_features(new_agents_array, av_id, num_historical_steps=11)
        data.update(map_data)
        with open(os.path.join(output_dir, scenario_id + '.pkl'), "wb+") as f:
            pickle.dump(data, f)
def batch_process9s_transformer(dir_name, output_dir, num_workers=2):
    """Run wm2argo over every TFRecord shard in ``dir_name`` with a worker pool."""
    from functools import partial
    import multiprocessing
    shard_files = os.listdir(dir_name)
    worker = partial(wm2argo, output_dir=output_dir, dir_name=dir_name)
    with multiprocessing.Pool(num_workers) as pool:
        # Drain the iterator so tqdm can track progress; results are None.
        for _ in tqdm(pool.imap(worker, shard_files), total=len(shard_files)):
            pass
from argparse import ArgumentParser
if __name__ == "__main__":
    # Convert raw Waymo scenario TFRecords into per-scenario pickle files.
    parser = ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='data/waymo/scenario/training')
    parser.add_argument('--output_dir', type=str, default='data/waymo_processed/training')
    args = parser.parse_args()
    files = os.listdir(args.input_dir)
    # Sequential conversion; each TFRecord shard may hold many scenarios.
    for file in tqdm(files):
        wm2argo(file, args.input_dir, args.output_dir)
    # Alternative: multi-process conversion across CPU cores.
    # batch_process9s_transformer(args.input_dir, args.output_dir, num_workers="ur_cpu_count")
================================================
FILE: environment.yml
================================================
name: smart
channels:
- pytorch
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli-python=1.0.9=py39h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.9.24=h06a4308_0
- certifi=2024.8.30=py39h06a4308_0
- charset-normalizer=3.3.2=pyhd3eb1b0_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- ffmpeg=4.3=hf484d3e_0
- freetype=2.12.1=h4a9f257_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- idna=3.7=py39h06a4308_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jpeg=9e=h5eee18b_3
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.40=h12ee557_0
- lerc=3.0=h295c915_0
- libdeflate=1.17=h5eee18b_1
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.14=0
- libidn2=2.3.4=h5eee18b_0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libwebp-base=1.3.2=h5eee18b_1
- lz4-c=1.9.4=h6a678d5_1
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py39h5eee18b_1
- mkl_fft=1.3.10=py39h5eee18b_0
- mkl_random=1.2.7=py39h1128e8f_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=he7f1fd0_0
- openssl=3.0.15=h5eee18b_0
- pillow=10.4.0=py39h5eee18b_0
- pip=24.2=py39h06a4308_0
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.19=h955ad1f_1
- pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
- pytorch-mutex=1.0=cuda
- readline=8.2=h5eee18b_0
- requests=2.32.3=py39h06a4308_0
- setuptools=75.1.0=py39h06a4308_0
- sqlite=3.45.3=h5eee18b_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.14=h39e8969_0
- torchvision=0.13.1=py39_cu113
- typing_extensions=4.11.0=py39h06a4308_0
- urllib3=2.2.3=py39h06a4308_0
- wheel=0.44.0=py39h06a4308_0
- xz=5.4.6=h5eee18b_1
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.6=hc292b87_0
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "smart"
version = "0.0.0"
description = "Scalable Multi-agent Real-time Motion Generation via Next-token Prediction"
readme = "README.md"
authors = [
{name = "Xiaoxin Feng"},
{name = "Ziyan Gao"},
{name = "Yuheng Kan"}
]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
]
requires-python = ">=3.9"
dependencies = [
"easydict",
"numpy",
"pandas",
"pytorch-lightning",
"scipy",
"torch-cluster",
"torch-geometric",
"torch-scatter",
"torch",
"torchmetrics",
"tqdm",
]
[project.urls]
"Homepage" = "https://smart-motion.github.io/smart/"
"Repository" = "https://github.com/rainmaker22/SMART"
"Paper" = "https://arxiv.org/abs/2405.15677"
[tool.setuptools]
packages = ["smart"]
================================================
FILE: requirements.txt
================================================
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
async-timeout==4.0.3
attrs==24.2.0
contourpy==1.3.0
cycler==0.12.1
easydict==1.13
fonttools==4.54.1
frozenlist==1.4.1
fsspec==2024.10.0
importlib-resources==6.4.5
jinja2==3.1.4
kiwisolver==1.4.7
lightning-utilities==0.11.8
markupsafe==3.0.2
matplotlib==3.9.2
multidict==6.1.0
numpy==1.26.4
packaging==24.1
pandas==2.0.3
propcache==0.2.0
psutil==6.1.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
pytorch-lightning==2.0.3
pytz==2024.2
pyyaml==6.0.1
scipy==1.10.1
shapely==2.0.6
six==1.16.0
torch-cluster==1.6.0+pt112cu113
torch-geometric==2.6.1
torch-scatter==2.1.0+pt112cu113
torch-sparse==0.6.16+pt112cu113
torch-spline-conv==1.2.1+pt112cu113
torchmetrics==1.5.0
tqdm==4.66.5
tzdata==2024.2
yarl==1.16.0
zipp==3.20.2
waymo-open-dataset-tf-2-12-0==1.6.4
================================================
FILE: scripts/install_pyg.sh
================================================
# Install the PyTorch-Geometric extension wheels built for torch 1.12 + CUDA 11.3
# (CPython 3.9, linux x86_64), then torch_geometric itself.
mkdir pyg_depend && cd pyg_depend
# Download the prebuilt wheels from the official PyG wheel index.
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.1.0%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.16%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
wget https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1%2Bpt112cu113-cp39-cp39-linux_x86_64.whl
# Install the compiled extensions first so torch_geometric picks them up.
python3 -m pip install torch_cluster-1.6.0+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_scatter-2.1.0+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_sparse-0.6.16+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_spline_conv-1.2.1+pt112cu113-cp39-cp39-linux_x86_64.whl
python3 -m pip install torch_geometric
================================================
FILE: scripts/traj_clstering.py
================================================
from smart.utils.geometry import wrap_angle
import numpy as np
def average_distance_vectorized(point_set1, centroids):
    """Mean per-corner Euclidean distance between every sample/centroid pair.

    point_set1: (num_samples, P, D); centroids: (num_centroids, P, D).
    Returns (num_samples, num_centroids).
    """
    diff = point_set1[:, None, :, :] - centroids[None, :, :, :]
    corner_dists = np.sqrt((diff ** 2).sum(axis=-1))
    return corner_dists.mean(axis=2)
def assign_clusters(sub_X, centroids):
    """Index of the nearest centroid (mean corner distance) for each sample."""
    # Inlined distance computation: mean over corners of per-corner distances.
    diff = sub_X[:, None, :, :] - centroids[None, :, :, :]
    distances = np.mean(np.sqrt(np.sum(diff ** 2, axis=-1)), axis=2)
    return np.argmin(distances, axis=1)
def Kdisk_cluster(X, N=256, tol=0.035, width=0, length=0, a_pos=None):
    """Greedy disk-based clustering of agent box contours.

    Repeatedly picks a random contour as a candidate center and removes every
    sample whose mean squared corner distance to it is within ``tol**2``.

    Args:
        X: (num_samples, 4, 2) box-corner contours (LF, RF, RB, LB order).
        N: number of cluster centers to produce.
        tol: disk radius on the mean-squared-corner-distance metric.
        width, length: box size used when re-deriving a contour from the mean
            heading of the removed samples (only when ``cal_mean_heading``).
        a_pos: trajectories aligned index-for-index with X.

    Returns:
        (centroids, ret_traj): stacked cluster contours and the trajectory
        associated with each cluster.

    NOTE(review): relies on the module-level ``cal_mean_heading`` flag and
    ``cal_polygon_contour``. If X is exhausted before N centers are found,
    ``np.random.choice(0)`` raises — assumes abundant samples; confirm.
    """
    S = []
    ret_traj_list = []
    while len(S) < N:
        num_all = X.shape[0]
        # Randomly pick a candidate cluster center.
        choice_index = np.random.choice(num_all)
        x0 = X[choice_index]
        # Reject candidates outside the region of interest.
        if x0[0, 0] < -10 or x0[0, 0] > 50 or x0[0, 1] > 10 or x0[0, 1] < -10:
            continue
        # Mean squared corner distance (4 corners) against the candidate.
        res_mask = np.sum((X - x0)**2, axis=(1, 2))/4 > (tol**2)
        del_mask = np.sum((X - x0)**2, axis=(1, 2))/4 <= (tol**2)
        if cal_mean_heading:
            # Replace the candidate by the mean pose of all samples in its disk.
            del_contour = X[del_mask]
            diff_xy = del_contour[:, 0, :] - del_contour[:, 3, :]  # left-front minus left-back corner
            del_heading = np.arctan2(diff_xy[:, 1], diff_xy[:, 0]).mean()
            x0 = cal_polygon_contour(x0.mean(0)[0], x0.mean(0)[1], del_heading, width, length)
            del_traj = a_pos[del_mask]
            ret_traj = del_traj.mean(0)[None, ...]
            # Debug trap for suspicious mean trajectories (large backward jump).
            if abs(ret_traj[0, 1, 0] - ret_traj[0, 0, 0]) > 1 and ret_traj[0, 1, 0] < 0:
                print(ret_traj)
                print('1')
        else:
            x0 = x0[None, ...]
            ret_traj = a_pos[choice_index][None, ...]
        # Drop everything inside the disk and keep clustering the remainder.
        X = X[res_mask]
        a_pos = a_pos[res_mask]
        S.append(x0)
        ret_traj_list.append(ret_traj)
    centroids = np.concatenate(S, axis=0)
    ret_traj = np.concatenate(ret_traj_list, axis=0)
    # (kmeans++-style seeding kept below for reference)
    # closest_dist_sq = np.sum((X - centroids[0])**2, axis=(1, 2))
    # for k in range(1, K):
    #     new_dist_sq = np.sum((X - centroids[k - 1])**2, axis=(1, 2))
    #     closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq)
    #     probabilities = closest_dist_sq / np.sum(closest_dist_sq)
    #     centroids[k] = X[np.random.choice(N, p=probabilities)]
    return centroids, ret_traj
def cal_polygon_contour(x, y, theta, width, length):
    """Corners of oriented boxes centred at (x, y) with heading theta.

    Returns (num_boxes, 4, 2) corners ordered left-front, right-front,
    right-back, left-back.
    """
    # Half-extent offsets along / across the heading, shared by all corners.
    dxl = 0.5 * length * np.cos(theta)
    dyl = 0.5 * length * np.sin(theta)
    dxw = 0.5 * width * np.sin(theta)
    dyw = 0.5 * width * np.cos(theta)
    left_front = np.column_stack((x + dxl - dxw, y + dyl + dyw))
    right_front = np.column_stack((x + dxl + dxw, y + dyl - dyw))
    right_back = np.column_stack((x - dxl + dxw, y - dyl - dyw))
    left_back = np.column_stack((x - dxl - dxw, y - dyl + dyw))
    polygon_contour = np.concatenate((left_front[:, None, :], right_front[:, None, :],
                                      right_back[:, None, :], left_back[:, None, :]), axis=1)
    return polygon_contour
if __name__ == '__main__':
    # Build the motion-token vocabulary by disk-clustering short trajectory
    # snippets per agent class (vehicle / cyclist / pedestrian).
    shift = 5  # motion token time dimension
    num_cluster = 6  # vocabulary size
    cal_mean_heading = True
    # Placeholder random data; in real use this holds trajectories collected
    # from the raw dataset.
    data = {
        "veh": np.random.rand(1000, 6, 3),
        "cyc": np.random.rand(1000, 6, 3),
        "ped": np.random.rand(1000, 6, 3)
    }
    # Collect the trajectories of all traffic participants from the raw data [NumAgent, shift+1, [relative_x, relative_y, relative_theta]]
    nms_res = {}
    res = {'token': {}, 'traj': {}, 'token_all': {}}
    for k, v in data.items():
        # if k != 'veh':
        #     continue
        a_pos = v
        print(a_pos.shape)
        # a_pos = a_pos[:, shift:1+shift, :]
        # Cap the number of samples used for clustering.
        cal_num = min(int(1e6), a_pos.shape[0])
        a_pos = a_pos[np.random.choice(a_pos.shape[0], cal_num, replace=False)]
        a_pos[:, :, -1] = wrap_angle(a_pos[:, :, -1])
        print(a_pos.shape)
        # Per-class nominal box size; smaller boxes for short (<=2 step) tokens.
        if shift <= 2:
            if k == 'veh':
                width = 1.0
                length = 2.4
            elif k == 'cyc':
                width = 0.5
                length = 1.5
            else:
                width = 0.5
                length = 0.5
        else:
            if k == 'veh':
                width = 2.0
                length = 4.8
            elif k == 'cyc':
                width = 1.0
                length = 2.0
            else:
                width = 1.0
                length = 1.0
        # Box contour at the token's final timestep (the pose being clustered).
        contour = cal_polygon_contour(a_pos[:, shift, 0], a_pos[:, shift, 1], a_pos[:, shift, 2], width, length)
        # plt.figure(figsize=(10, 10))
        # for rect in contour:
        #     rect_closed = np.vstack([rect, rect[0]])
        #     plt.plot(rect_closed[:, 0], rect_closed[:, 1], linewidth=0.1)
        # plt.title("Plot of 256 Rectangles")
        # plt.xlabel("x")
        # plt.ylabel("y")
        # plt.axis('equal')
        # plt.savefig(f'src_{k}_new.jpg', dpi=300)
        # Per-class disk radius for the clustering metric.
        if k == 'veh':
            tol = 0.05
        elif k == 'cyc':
            tol = 0.004
        else:
            tol = 0.004
        centroids, ret_traj = Kdisk_cluster(contour, num_cluster, tol, width, length, a_pos[:, :shift+1])
        # plt.figure(figsize=(10, 10))
        # Re-derive contours for every timestep of each cluster trajectory.
        contour = cal_polygon_contour(ret_traj[:, :, 0].reshape(num_cluster*(shift+1)),
                                      ret_traj[:, :, 1].reshape(num_cluster*(shift+1)),
                                      ret_traj[:, :, 2].reshape(num_cluster*(shift+1)), width, length)
        res['token_all'][k] = contour.reshape(num_cluster, (shift+1), 4, 2)
        res['token'][k] = centroids
        res['traj'][k] = ret_traj
FILE: smart/__init__.py
================================================
================================================
FILE: smart/datamodules/__init__.py
================================================
from smart.datamodules.scalable_datamodule import MultiDataModule
================================================
FILE: smart/datamodules/scalable_datamodule.py
================================================
from typing import Optional
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
from smart.datasets.scalable_dataset import MultiDataset
from smart.transforms import WaymoTargetBuilder
class MultiDataModule(pl.LightningDataModule):
    """Lightning data module wiring the scalable Waymo datasets together.

    Builds one dataset per split in :meth:`setup` and serves them through the
    standard train/val/test dataloader hooks. Transform and dataset classes
    are selected by name from the class-level registries.
    """

    # Target-builder transforms selectable by name from the config.
    transforms = {
        "WaymoTargetBuilder": WaymoTargetBuilder,
    }

    # Dataset classes selectable by name from the config.
    dataset = {
        "scalable": MultiDataset,
    }

    def __init__(self,
                 root: str,
                 train_batch_size: int,
                 val_batch_size: int,
                 test_batch_size: int,
                 shuffle: bool = False,
                 num_workers: int = 0,
                 pin_memory: bool = True,
                 persistent_workers: bool = True,
                 train_raw_dir: Optional[str] = None,
                 val_raw_dir: Optional[str] = None,
                 test_raw_dir: Optional[str] = None,
                 train_processed_dir: Optional[str] = None,
                 val_processed_dir: Optional[str] = None,
                 test_processed_dir: Optional[str] = None,
                 transform: Optional[str] = None,
                 dataset: Optional[str] = None,
                 num_historical_steps: int = 50,
                 num_future_steps: int = 60,
                 processor='ntp',
                 use_intention=False,
                 token_size=512,
                 **kwargs) -> None:
        """Store configuration and build the per-split target transforms.

        Args:
            root: dataset root directory (used by the train split).
            train_batch_size / val_batch_size / test_batch_size: loader sizes.
            shuffle: whether to shuffle the training loader.
            num_workers: dataloader worker processes.
            pin_memory: pin host memory for faster GPU transfer.
            persistent_workers: keep workers alive between epochs (only
                effective when num_workers > 0).
            *_raw_dir / *_processed_dir: per-split data locations.
            transform: key into :attr:`transforms`.
            dataset: key into :attr:`dataset`.
            num_historical_steps / num_future_steps: passed to the transform.
            processor / use_intention / token_size: forwarded to the datasets.
        """
        super().__init__()  # modern zero-arg form of super()
        self.root = root
        self.dataset_class = dataset
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        # Persistent workers are only meaningful with a live worker pool.
        self.persistent_workers = persistent_workers and num_workers > 0
        self.train_raw_dir = train_raw_dir
        self.val_raw_dir = val_raw_dir
        self.test_raw_dir = test_raw_dir
        self.train_processed_dir = train_processed_dir
        self.val_processed_dir = val_processed_dir
        self.test_processed_dir = test_processed_dir
        self.processor = processor
        self.use_intention = use_intention
        self.token_size = token_size
        # One target-builder per split; the test transform uses the builder's
        # default mode (no explicit split argument).
        train_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "train")
        val_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps, "val")
        test_transform = MultiDataModule.transforms[transform](num_historical_steps, num_future_steps)
        self.train_transform = train_transform
        self.val_transform = val_transform
        self.test_transform = test_transform

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the train/val/test datasets (Lightning hook)."""
        self.train_dataset = MultiDataModule.dataset[self.dataset_class](self.root, 'train', processed_dir=self.train_processed_dir,
                                                                         raw_dir=self.train_raw_dir, processor=self.processor, transform=self.train_transform, token_size=self.token_size)
        self.val_dataset = MultiDataModule.dataset[self.dataset_class](None, 'val', processed_dir=self.val_processed_dir,
                                                                       raw_dir=self.val_raw_dir, processor=self.processor, transform=self.val_transform, token_size=self.token_size)
        self.test_dataset = MultiDataModule.dataset[self.dataset_class](None, 'test', processed_dir=self.test_processed_dir,
                                                                        raw_dir=self.test_raw_dir, processor=self.processor, transform=self.test_transform, token_size=self.token_size)

    def train_dataloader(self):
        """Training loader; shuffling is controlled by the `shuffle` flag."""
        return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)

    def val_dataloader(self):
        """Validation loader; never shuffled."""
        return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)

    def test_dataloader(self):
        """Test loader; never shuffled."""
        return DataLoader(self.test_dataset, batch_size=self.test_batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory,
                          persistent_workers=self.persistent_workers)
================================================
FILE: smart/datasets/__init__.py
================================================
from smart.datasets.scalable_dataset import MultiDataset
================================================
FILE: smart/datasets/preprocess.py
================================================
import torch
import numpy as np
from scipy.interpolate import interp1d
from scipy.spatial.distance import euclidean
import math
import pickle
from smart.utils import wrap_angle
import os
def cal_polygon_contour(x, y, theta, width, length):
    """Return the four box corners for each (x, y, theta) pose.

    Output shape is (num_boxes, 4, 2), corners ordered left-front,
    right-front, right-back, left-back.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    half_l = 0.5 * length
    half_w = 0.5 * width
    corners = []
    # (sign along heading, sign across heading) for LF, RF, RB, LB.
    for sl, sw in ((1.0, -1.0), (1.0, 1.0), (-1.0, 1.0), (-1.0, -1.0)):
        cx = x + sl * half_l * cos_t + sw * half_w * sin_t
        cy = y + sl * half_l * sin_t - sw * half_w * cos_t
        corners.append(np.column_stack((cx, cy))[:, None, :])
    return np.concatenate(corners, axis=1)
def interplating_polyline(polylines, heading, distance=0.5, split_distace=5):
    """Resample a polyline at ``distance`` spacing and split it into segments.

    The polyline is first broken wherever the heading changes sharply or
    consecutive points are far apart; each piece is then linearly interpolated
    to ``distance`` resolution and chopped into ``split_distace``-metre
    windows of 3 points each (after subsampling).

    Args:
        polylines: (N, >=2) array; the first two columns are x, y.
        heading: (N,) per-point heading in radians.
        distance: resampling step in metres.
        split_distace: segment length in metres (original spelling kept for
            call-site compatibility).

    Returns:
        torch.Tensor of shape (num_segments, 3, 3) with [x, y, heading] per
        point, or None if no segment could be built.
    """
    # Accumulate arc length; start a new piece at discontinuities.
    dist_along_path_list = [[0]]
    polylines_list = [[polylines[0]]]
    for i in range(1, polylines.shape[0]):
        euclidean_dist = euclidean(polylines[i, :2], polylines[i - 1, :2])
        # BUG FIX: the original read heading[1] (constant index) instead of
        # heading[i] inside min(...), twice.
        # NOTE(review): the `+ math.pi` wrap term looks like an imperfect
        # angular difference; kept as-is to preserve the original intent.
        heading_diff = min(abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1])),
                           abs(max(heading[i], heading[i - 1]) - min(heading[i], heading[i - 1]) + math.pi))
        # The original chained three thresholds (pi/4, pi/8, 0.1) with identical
        # bodies; since they nest, only the weakest (0.1) matters.
        if (heading_diff > 0.1 and euclidean_dist > 3) or euclidean_dist > 10:
            # Sharp turn or large gap: start a new piece.
            dist_along_path_list.append([0])
            polylines_list.append([polylines[i]])
        else:
            dist_along_path_list[-1].append(dist_along_path_list[-1][-1] + euclidean_dist)
            polylines_list[-1].append(polylines[i])
    new_x_list = []
    new_y_list = []
    multi_polylines_list = []
    for idx in range(len(dist_along_path_list)):
        if len(dist_along_path_list[idx]) < 2:
            continue  # a single point cannot be interpolated
        dist_along_path = np.array(dist_along_path_list[idx])
        polylines_cur = np.array(polylines_list[idx])
        # Linear interpolators for x and y over arc length.
        fx = interp1d(dist_along_path, polylines_cur[:, 0])
        fy = interp1d(dist_along_path, polylines_cur[:, 1])
        # Sample every `distance` metres, always keeping the final point.
        new_dist_along_path = np.arange(0, dist_along_path[-1], distance)
        new_dist_along_path = np.concatenate([new_dist_along_path, dist_along_path[[-1]]])
        new_x = fx(new_dist_along_path)
        new_y = fy(new_dist_along_path)
        new_x_list.append(new_x)
        new_y_list.append(new_y)
        new_polylines = np.vstack((new_x, new_y)).T
        polyline_size = int(split_distace / distance)
        if new_polylines.shape[0] >= (polyline_size + 1):
            padding_size = (new_polylines.shape[0] - (polyline_size + 1)) % polyline_size
            final_index = (new_polylines.shape[0] - (polyline_size + 1)) // polyline_size + 1
        else:
            padding_size = new_polylines.shape[0]
            final_index = 0
        multi_polylines = None
        new_polylines = torch.from_numpy(new_polylines)
        # Derive headings from consecutive differences; repeat the last one.
        new_heading = torch.atan2(new_polylines[1:, 1] - new_polylines[:-1, 1],
                                  new_polylines[1:, 0] - new_polylines[:-1, 0])
        new_heading = torch.cat([new_heading, new_heading[-1:]], -1)[..., None]
        new_polylines = torch.cat([new_polylines, new_heading], -1)
        if new_polylines.shape[0] >= (polyline_size + 1):
            # Sliding windows of polyline_size+1 points, keeping every 5th point.
            multi_polylines = new_polylines.unfold(dimension=0, size=polyline_size + 1, step=polyline_size)
            multi_polylines = multi_polylines.transpose(1, 2)
            multi_polylines = multi_polylines[:, ::5, :]
        if padding_size >= 3:
            # The leftover tail is long enough to form one more 3-point segment.
            last_polyline = new_polylines[final_index * polyline_size:]
            last_polyline = last_polyline[torch.linspace(0, last_polyline.shape[0] - 1, steps=3).long()]
            if multi_polylines is not None:
                multi_polylines = torch.cat([multi_polylines, last_polyline.unsqueeze(0)], dim=0)
            else:
                multi_polylines = last_polyline.unsqueeze(0)
        if multi_polylines is None:
            continue
        multi_polylines_list.append(multi_polylines)
    if len(multi_polylines_list) > 0:
        multi_polylines_list = torch.cat(multi_polylines_list, dim=0)
    else:
        multi_polylines_list = None
    return multi_polylines_list
def average_distance_vectorized(point_set1, centroids):
    """Mean point-to-point Euclidean distance between every item of
    `point_set1` and every centroid.

    Both inputs are 3-D arrays of matching trailing shape — presumably
    (num_items, points_per_traj, coord_dim) and
    (num_centroids, points_per_traj, coord_dim); TODO confirm with caller.
    Returns an array of shape (num_items, num_centroids).
    """
    # Broadcast to (num_items, num_centroids, points, dim) and take the
    # per-point L2 norm, then average over the points axis.
    pairwise = np.linalg.norm(point_set1[:, None, :, :] - centroids[None, :, :, :], axis=-1)
    return pairwise.mean(axis=2)
def assign_clusters(sub_X, centroids):
    """Assign each trajectory in `sub_X` to the nearest centroid.

    Distance is the mean point-to-point Euclidean distance (the same
    measure as `average_distance_vectorized`, inlined here).
    Returns an int array of centroid indices, one per row of `sub_X`.
    """
    diffs = sub_X[:, None, :, :] - centroids[None, :, :, :]
    mean_dists = np.sqrt((diffs ** 2).sum(axis=-1)).mean(axis=2)
    return mean_dists.argmin(axis=1)
class TokenProcessor:
    """Converts raw agent trajectories and map polylines into discrete tokens.

    Agent motion is matched every `shift` frames against a pre-clustered
    vocabulary of short trajectory snippets loaded from
    `tokens/cluster_frame_5_<token_size>.pkl`; map polylines are resampled,
    split, and stored as point tokens. Results are written back into the
    heterogeneous `data` dict in place.
    """

    def __init__(self, token_size):
        # Token vocabularies live in the package-level `tokens/` directory
        # (two levels up from this module's file).
        module_dir = os.path.dirname(os.path.dirname(__file__))
        self.agent_token_path = os.path.join(module_dir, f'tokens/cluster_frame_5_{token_size}.pkl')
        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        # When True, match_token may swap a matched token for a random
        # top-5 candidate (see the `self.noise` branch there).
        self.noise = False
        self.disturb = False
        # Number of raw frames covered by one motion token.
        self.shift = 5
        self.get_trajectory_token()
        self.training = False
        # Index of the "current" observation frame in the raw sequence.
        self.current_step = 10

    def preprocess(self, data):
        """Tokenize agents, then the map, and drop fields unused downstream."""
        data = self.tokenize_agent(data)
        data = self.tokenize_map(data)
        del data['city']
        if 'polygon_is_intersection' in data['map_polygon']:
            del data['map_polygon']['polygon_is_intersection']
        if 'route_type' in data['map_polygon']:
            del data['map_polygon']['route_type']
        return data

    def get_trajectory_token(self):
        """Load the agent and map token vocabularies from disk.

        Also builds `self.token_last[k]` per agent class: each token's
        final contour, rotated into the frame of the second-to-last
        contour and re-centered. It is used by the "extra matching" step
        in `match_token` for agents that appear mid-history.
        """
        agent_token_data = pickle.load(open(self.agent_token_path, 'rb'))
        map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
        self.trajectory_token = agent_token_data['token']
        self.trajectory_token_all = agent_token_data['token_all']
        self.map_token = {'traj_src': map_token_traj['traj_src'], }
        self.token_last = {}
        for k, v in self.trajectory_token_all.items():
            # Last two contour frames of every token: [:, 0] is the
            # anchor frame, [:, 1] the final frame.
            token_last = torch.from_numpy(v[:, -2:]).to(torch.float)
            # Heading of the anchor frame from front-corner minus
            # rear-corner (contour points 0 and 3).
            diff_xy = token_last[:, 0, 0] - token_last[:, 0, 3]
            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(token_last.shape[0], 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = -sin
            rot_mat[:, 1, 0] = sin
            rot_mat[:, 1, 1] = cos
            # Rotate the final frame into the anchor frame and re-center
            # on the anchor contour's centroid.
            agent_token = torch.bmm(token_last[:, 1], rot_mat)
            agent_token -= token_last[:, 0].mean(1)[:, None, :]
            self.token_last[k] = agent_token.numpy()

    def clean_heading(self, data):
        """Suppress implausible frame-to-frame heading jumps in place.

        A wrapped heading difference larger than 1.0 rad between two valid
        consecutive frames is treated as noise: the later frame's heading
        is overwritten with the earlier one, and the affected difference
        is recomputed before the sweep continues.
        """
        heading = data['agent']['heading']
        valid = data['agent']['valid_mask']
        pi = torch.tensor(torch.pi)
        n_vehicles, n_frames = heading.shape
        # Wrap successive differences into (-pi, pi].
        heading_diff_raw = heading[:, :-1] - heading[:, 1:]
        heading_diff = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
        heading_diff[heading_diff > pi] -= 2 * pi
        heading_diff[heading_diff < -pi] += 2 * pi
        valid_pairs = valid[:, :-1] & valid[:, 1:]
        for i in range(n_frames - 1):
            change_needed = (torch.abs(heading_diff[:, i:i + 1]) > 1.0) & valid_pairs[:, i:i + 1]
            # Propagate the previous heading over the jump.
            heading[:, i + 1][change_needed.squeeze()] = heading[:, i][change_needed.squeeze()]
            if i < n_frames - 2:
                heading_diff_raw = heading[:, i + 1] - heading[:, i + 2]
                heading_diff[:, i + 1] = torch.remainder(heading_diff_raw + pi, 2 * pi) - pi
                # NOTE(review): these two lines index *rows* of heading_diff
                # with a per-agent boolean mask (not column i+1 only), so they
                # shift whole rows by 2*pi — looks unintended; confirm.
                heading_diff[heading_diff[:, i + 1] > pi] -= 2 * pi
                heading_diff[heading_diff[:, i + 1] < -pi] += 2 * pi

    def tokenize_agent(self, data):
        """Match every agent's trajectory to the motion-token vocabulary.

        Writes `token_idx`, `token_contour`, `token_pos`, `token_heading`,
        `agent_valid_mask` and `token_velocity` into `data['agent']`.
        """
        if data['agent']["velocity"].shape[1] == 90:
            print(data['scenario_id'], data['agent']["velocity"].shape)
        # Agents invalid at the current step but with a nonzero x position:
        # back-fill the previous frame from the current one using velocity.
        interplote_mask = (data['agent']['valid_mask'][:, self.current_step] == False) * (
                data['agent']['position'][:, self.current_step, 0] != 0)
        if data['agent']["velocity"].shape[-1] == 2:
            # Pad a zero z-velocity channel so velocity is always 3-D.
            data['agent']["velocity"] = torch.cat([data['agent']["velocity"],
                                                   torch.zeros(data['agent']["velocity"].shape[0],
                                                               data['agent']["velocity"].shape[1], 1)], dim=-1)
        vel = data['agent']["velocity"][interplote_mask, self.current_step]
        # One Euler step backwards (0.1 s per frame).
        data['agent']['position'][interplote_mask, self.current_step - 1, :3] = data['agent']['position'][
            interplote_mask, self.current_step, :3] - vel * 0.1
        data['agent']['valid_mask'][interplote_mask, self.current_step - 1:self.current_step + 1] = True
        data['agent']['heading'][interplote_mask, self.current_step - 1] = data['agent']['heading'][
            interplote_mask, self.current_step]
        data['agent']["velocity"][interplote_mask, self.current_step - 1] = data['agent']["velocity"][
            interplote_mask, self.current_step]
        data['agent']['type'] = data['agent']['type'].to(torch.uint8)
        self.clean_heading(data)
        # Agents valid now but not one shift ago: candidates for the
        # "extra matching" pass inside match_token.
        matching_extra_mask = (data['agent']['valid_mask'][:, self.current_step] == True) * (
                data['agent']['valid_mask'][:, self.current_step - 5] == False)
        interplote_mask_first = (data['agent']['valid_mask'][:, 0] == False) * (data['agent']['position'][:, 0, 0] != 0)
        data['agent']['valid_mask'][interplote_mask_first, 0] = True
        agent_pos = data['agent']['position'][:, :, :2]
        valid_mask = data['agent']['valid_mask']
        # A token interval is valid iff both its endpoints are valid.
        valid_mask_shift = valid_mask.unfold(1, self.shift + 1, self.shift)
        token_valid_mask = valid_mask_shift[:, :, 0] * valid_mask_shift[:, :, -1]
        agent_type = data['agent']['type']
        agent_category = data['agent']['category']
        agent_heading = data['agent']['heading']
        # Type codes: 0 = vehicle, 1 = pedestrian, 2 = cyclist.
        vehicle_mask = agent_type == 0
        cyclist_mask = agent_type == 2
        ped_mask = agent_type == 1
        veh_pos = agent_pos[vehicle_mask, :, :]
        veh_valid_mask = valid_mask[vehicle_mask, :]
        cyc_pos = agent_pos[cyclist_mask, :, :]
        cyc_valid_mask = valid_mask[cyclist_mask, :]
        ped_pos = agent_pos[ped_mask, :, :]
        ped_valid_mask = valid_mask[ped_mask, :]
        # Match each class against its own vocabulary.
        veh_token_index, veh_token_contour = self.match_token(veh_pos, veh_valid_mask, agent_heading[vehicle_mask],
                                                              'veh', agent_category[vehicle_mask],
                                                              matching_extra_mask[vehicle_mask])
        ped_token_index, ped_token_contour = self.match_token(ped_pos, ped_valid_mask, agent_heading[ped_mask], 'ped',
                                                              agent_category[ped_mask], matching_extra_mask[ped_mask])
        cyc_token_index, cyc_token_contour = self.match_token(cyc_pos, cyc_valid_mask, agent_heading[cyclist_mask],
                                                              'cyc', agent_category[cyclist_mask],
                                                              matching_extra_mask[cyclist_mask])
        # Scatter the per-class results back into full-size tensors.
        token_index = torch.zeros((agent_pos.shape[0], veh_token_index.shape[1])).to(torch.int64)
        token_index[vehicle_mask] = veh_token_index
        token_index[ped_mask] = ped_token_index
        token_index[cyclist_mask] = cyc_token_index
        token_contour = torch.zeros((agent_pos.shape[0], veh_token_contour.shape[1],
                                     veh_token_contour.shape[2], veh_token_contour.shape[3]))
        token_contour[vehicle_mask] = veh_token_contour
        token_contour[ped_mask] = ped_token_contour
        token_contour[cyclist_mask] = cyc_token_contour
        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(torch.float)
        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(torch.float)
        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(torch.float)
        # NOTE(review): agent_token_traj is built but never stored or
        # returned — appears to be dead code here.
        agent_token_traj = torch.zeros((agent_pos.shape[0], trajectory_token_veh.shape[0], 4, 2))
        agent_token_traj[vehicle_mask] = trajectory_token_veh
        agent_token_traj[ped_mask] = trajectory_token_ped
        agent_token_traj[cyclist_mask] = trajectory_token_cyc
        if not self.training:
            # At eval time, trust the extra-matched second token slot.
            token_valid_mask[matching_extra_mask, 1] = True
        data['agent']['token_idx'] = token_index
        data['agent']['token_contour'] = token_contour
        # Token position = contour centroid; heading from contour corners 0 and 3.
        token_pos = token_contour.mean(dim=2)
        data['agent']['token_pos'] = token_pos
        diff_xy = token_contour[:, :, 0, :] - token_contour[:, :, 3, :]
        data['agent']['token_heading'] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
        data['agent']['agent_valid_mask'] = token_valid_mask
        # Finite-difference velocity across token steps (shift frames of 0.1 s).
        vel = torch.cat([token_pos.new_zeros(data['agent']['num_nodes'], 1, 2),
                         ((token_pos[:, 1:] - token_pos[:, :-1]) / (0.1 * self.shift))], dim=1)
        vel_valid_mask = torch.cat([torch.zeros(token_valid_mask.shape[0], 1, dtype=torch.bool),
                                    (token_valid_mask * token_valid_mask.roll(shifts=1, dims=1))[:, 1:]], dim=1)
        vel[~vel_valid_mask] = 0
        # At the current step, prefer the measured velocity over the estimate.
        vel[data['agent']['valid_mask'][:, self.current_step], 1] = data['agent']['velocity'][
            data['agent']['valid_mask'][:, self.current_step], self.current_step, :2]
        data['agent']['token_velocity'] = vel
        return data

    def match_token(self, pos, valid_mask, heading, category, agent_category, extra_mask):
        """Greedily match one agent class's trajectories to motion tokens.

        Every `self.shift` frames, the token vocabulary is rotated/translated
        into the agent's previous pose, and the token whose contour is closest
        (mean corner distance) to the agent's current contour is selected.
        Returns (token_index, token_contour) for the class's agents.
        """
        agent_token_src = self.trajectory_token[category]
        token_last = self.token_last[category]
        # Nominal contour footprint per class; larger boxes for shift > 2.
        if self.shift <= 2:
            if category == 'veh':
                width = 1.0
                length = 2.4
            elif category == 'cyc':
                width = 0.5
                length = 1.5
            else:
                width = 0.5
                length = 0.5
        else:
            if category == 'veh':
                width = 2.0
                length = 4.8
            elif category == 'cyc':
                width = 1.0
                length = 2.0
            else:
                width = 1.0
                length = 1.0
        prev_heading = heading[:, 0]
        prev_pos = pos[:, 0]
        agent_num, num_step, feat_dim = pos.shape
        token_num, token_contour_dim, feat_dim = agent_token_src.shape
        # Tile the vocabulary once per agent (numpy repeat along axis 0).
        agent_token_src = agent_token_src.reshape(1, token_num * token_contour_dim, feat_dim).repeat(agent_num, 0)
        token_last = token_last.reshape(1, token_num * token_contour_dim, feat_dim).repeat(extra_mask.sum(), 0)
        token_index_list = []
        token_contour_list = []
        prev_token_idx = None
        for i in range(self.shift, pos.shape[1], self.shift):
            theta = prev_heading
            cur_heading = heading[:, i]
            cur_pos = pos[:, i]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(agent_num, 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            # Vocabulary contours in world frame, anchored at the previous pose.
            agent_token_world = torch.bmm(torch.from_numpy(agent_token_src).to(torch.float), rot_mat).reshape(
                agent_num, token_num, token_contour_dim, feat_dim)
            agent_token_world += prev_pos[:, None, None, :]
            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
            # Nearest token by mean corner-to-corner distance.
            agent_token_index = torch.from_numpy(np.argmin(
                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
                axis=-1))
            if prev_token_idx is not None and self.noise:
                # Data augmentation: resample every token from its top-5
                # nearest candidates (same_idx is forced all-True below).
                same_idx = prev_token_idx == agent_token_index
                same_idx[:] = True
                topk_indices = np.argsort(
                    np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)),
                            axis=2), axis=-1)[:, :5]
                sample_topk = np.random.choice(range(0, topk_indices.shape[1]), topk_indices.shape[0])
                agent_token_index[same_idx] = \
                    torch.from_numpy(topk_indices[np.arange(topk_indices.shape[0]), sample_topk])[same_idx]
            token_contour_select = agent_token_world[torch.arange(agent_num), agent_token_index]
            diff_xy = token_contour_select[:, 0, :] - token_contour_select[:, 3, :]
            # Roll the matched token's pose forward as the next anchor,
            # but only where the interval start was valid; otherwise fall
            # back to the raw observation.
            prev_heading = heading[:, i].clone()
            prev_heading[valid_mask[:, i - self.shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
                valid_mask[:, i - self.shift]]
            prev_pos = pos[:, i].clone()
            prev_pos[valid_mask[:, i - self.shift]] = token_contour_select.mean(dim=1)[valid_mask[:, i - self.shift]]
            prev_token_idx = agent_token_index
            token_index_list.append(agent_token_index[:, None])
            token_contour_list.append(token_contour_select[:, None, ...])
        token_index = torch.cat(token_index_list, dim=1)
        token_contour = torch.cat(token_contour_list, dim=1)
        # extra matching: agents that appeared mid-history get their second
        # token slot re-matched from the single-step `token_last` vocabulary.
        if not self.training:
            theta = heading[extra_mask, self.current_step - 1]
            prev_pos = pos[extra_mask, self.current_step - 1]
            cur_pos = pos[extra_mask, self.current_step]
            cur_heading = heading[extra_mask, self.current_step]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = theta.new_zeros(extra_mask.sum(), 2, 2)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_token_world = torch.bmm(torch.from_numpy(token_last).to(torch.float), rot_mat).reshape(
                extra_mask.sum(), token_num, token_contour_dim, feat_dim)
            agent_token_world += prev_pos[:, None, None, :]
            cur_contour = cal_polygon_contour(cur_pos[:, 0], cur_pos[:, 1], cur_heading, width, length)
            agent_token_index = torch.from_numpy(np.argmin(
                np.mean(np.sqrt(np.sum((cur_contour[:, None, ...] - agent_token_world.numpy()) ** 2, axis=-1)), axis=2),
                axis=-1))
            token_contour_select = agent_token_world[torch.arange(extra_mask.sum()), agent_token_index]
            token_index[extra_mask, 1] = agent_token_index
            token_contour[extra_mask, 1] = token_contour_select
        return token_index, token_contour

    def tokenize_map(self, data):
        """Split each map polygon's points into resampled polyline tokens.

        Groups points by (side, type) within each polygon, resamples each
        group via `interplating_polyline`, and writes the results into
        `data['map_save']` and `data['pt_token']`.
        """
        data['map_polygon']['type'] = data['map_polygon']['type'].to(torch.uint8)
        data['map_point']['type'] = data['map_point']['type'].to(torch.uint8)
        pt2pl = data[('map_point', 'to', 'map_polygon')]['edge_index']
        pt_type = data['map_point']['type'].to(torch.uint8)
        # Side information is not populated here — all zeros.
        pt_side = torch.zeros_like(pt_type)
        pt_pos = data['map_point']['position'][:, :2]
        data['map_point']['orientation'] = wrap_angle(data['map_point']['orientation'])
        pt_heading = data['map_point']['orientation']
        split_polyline_type = []
        split_polyline_pos = []
        split_polyline_theta = []
        split_polyline_side = []
        pl_idx_list = []
        split_polygon_type = []
        data['map_point']['type'].unique()
        for i in sorted(np.unique(pt2pl[1])):
            # All points belonging to polygon i.
            index = pt2pl[0, pt2pl[1] == i]
            polygon_type = data['map_polygon']["type"][i]
            cur_side = pt_side[index]
            cur_type = pt_type[index]
            cur_pos = pt_pos[index]
            cur_heading = pt_heading[index]
            for side_val in np.unique(cur_side):
                for type_val in np.unique(cur_type):
                    # Point type 13 is skipped entirely — meaning not
                    # visible here; presumably a "do not tokenize" class.
                    if type_val == 13:
                        continue
                    indices = np.where((cur_side == side_val) & (cur_type == type_val))[0]
                    # Need at least 3 points to form a polyline segment.
                    if len(indices) <= 2:
                        continue
                    split_polyline = interplating_polyline(cur_pos[indices].numpy(), cur_heading[indices].numpy())
                    if split_polyline is None:
                        continue
                    new_cur_type = cur_type[indices][0]
                    new_cur_side = cur_side[indices][0]
                    # Repeat per-group attributes once per produced segment.
                    map_polygon_type = polygon_type.repeat(split_polyline.shape[0])
                    new_cur_type = new_cur_type.repeat(split_polyline.shape[0])
                    new_cur_side = new_cur_side.repeat(split_polyline.shape[0])
                    cur_pl_idx = torch.Tensor([i])
                    new_cur_pl_idx = cur_pl_idx.repeat(split_polyline.shape[0])
                    split_polyline_pos.append(split_polyline[..., :2])
                    split_polyline_theta.append(split_polyline[..., 2])
                    split_polyline_type.append(new_cur_type)
                    split_polyline_side.append(new_cur_side)
                    pl_idx_list.append(new_cur_pl_idx)
                    split_polygon_type.append(map_polygon_type)
        # NOTE(review): if no polylines survive the filters, these cats
        # raise on empty lists — confirm inputs always yield at least one.
        split_polyline_pos = torch.cat(split_polyline_pos, dim=0)
        split_polyline_theta = torch.cat(split_polyline_theta, dim=0)
        split_polyline_type = torch.cat(split_polyline_type, dim=0)
        split_polyline_side = torch.cat(split_polyline_side, dim=0)
        split_polygon_type = torch.cat(split_polygon_type, dim=0)
        pl_idx_list = torch.cat(pl_idx_list, dim=0)
        vec = split_polyline_pos[:, 1, :] - split_polyline_pos[:, 0, :]
        data['map_save'] = {}
        data['pt_token'] = {}
        data['map_save']['traj_pos'] = split_polyline_pos
        data['map_save']['traj_theta'] = split_polyline_theta[:, 0]  # torch.arctan2(vec[:, 1], vec[:, 0])
        data['map_save']['pl_idx_list'] = pl_idx_list
        data['pt_token']['type'] = split_polyline_type
        data['pt_token']['side'] = split_polyline_side
        data['pt_token']['pl_type'] = split_polygon_type
        data['pt_token']['num_nodes'] = split_polyline_pos.shape[0]
        return data
================================================
FILE: smart/datasets/scalable_dataset.py
================================================
import os
import pickle
from typing import Callable, List, Optional, Tuple, Union
import pandas as pd
from torch_geometric.data import Dataset
from smart.utils.log import Logging
import numpy as np
from .preprocess import TokenProcessor
def distance(point1, point2):
    """Return the Euclidean distance between two 2-D points.

    Uses np.hypot instead of sqrt(dx**2 + dy**2): mathematically identical,
    but robust against overflow/underflow of the squared intermediates.
    """
    return np.hypot(point2[0] - point1[0], point2[1] - point1[1])
class MultiDataset(Dataset):
    """PyG dataset over pickled scenes, tokenized on the fly.

    Raw scene pickles are enumerated from `raw_dir`; the train/val split
    membership is read from `<root>/split_datainfo.pkl`. Each `get` loads
    one pickle and runs it through `TokenProcessor.preprocess`.

    Fixes vs. previous revision:
    - split_datainfo.pkl is opened read-only ('rb'); 'rb+' needlessly
      required write permission for a pure read.
    - `token_size` is now forwarded to TokenProcessor (it was hard-coded
      to 2048, silently ignoring the parameter). Note the default of 512
      requires a matching `tokens/cluster_frame_5_512.pkl`; callers should
      pass the vocabulary size they actually ship (the repo bundles 2048).
    """

    def __init__(self,
                 root: str,
                 split: str,
                 raw_dir: List[str] = None,
                 processed_dir: List[str] = None,
                 transform: Optional[Callable] = None,
                 dim: int = 3,
                 num_historical_steps: int = 50,
                 num_future_steps: int = 60,
                 predict_unseen_agents: bool = False,
                 vector_repr: bool = True,
                 cluster: bool = False,
                 processor=None,
                 use_intention=False,
                 token_size=512) -> None:
        # NOTE: num_future_steps, predict_unseen_agents, vector_repr,
        # cluster, processor and use_intention are accepted for interface
        # compatibility but unused here.
        self.logger = Logging().log(level='DEBUG')
        self.root = root
        self.well_done = [0]
        if split not in ('train', 'val', 'test'):
            raise ValueError(f'{split} is not a valid split')
        self.split = split
        self.training = split == 'train'
        self.logger.debug("Starting loading dataset")
        self._raw_file_names = []
        self._raw_paths = []
        self._raw_file_dataset = []
        if raw_dir is not None:
            self._raw_dir = raw_dir
            for cur_dir in self._raw_dir:
                cur_dir = os.path.expanduser(os.path.normpath(cur_dir))
                dataset = "waymo"
                file_list = os.listdir(cur_dir)
                self._raw_file_names.extend(file_list)
                self._raw_paths.extend([os.path.join(cur_dir, f) for f in file_list])
                self._raw_file_dataset.extend([dataset for _ in range(len(file_list))])
        if self.root is not None:
            split_datainfo = os.path.join(root, "split_datainfo.pkl")
            # Read-only access is sufficient (was 'rb+').
            with open(split_datainfo, 'rb') as f:
                split_datainfo = pickle.load(f)
            # There is no separate test split on disk; reuse val.
            if split == "test":
                split = "val"
            self._processed_file_names = split_datainfo[split]
        self.dim = dim
        self.num_historical_steps = num_historical_steps
        # NOTE(review): the "- 1" when processed_dir is given looks like an
        # off-by-one guard inherited from elsewhere — confirm intent.
        self._num_samples = len(self._processed_file_names) - 1 if processed_dir is not None else len(self._raw_file_names)
        self.logger.debug("The number of {} dataset is ".format(split) + str(self._num_samples))
        # Honor the caller-supplied vocabulary size (was hard-coded 2048).
        self.token_processor = TokenProcessor(token_size)
        super(MultiDataset, self).__init__(root=root, transform=transform, pre_transform=None, pre_filter=None)

    @property
    def raw_dir(self) -> str:
        return self._raw_dir

    @property
    def raw_paths(self) -> List[str]:
        return self._raw_paths

    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        return self._raw_file_names

    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        return self._processed_file_names

    def len(self) -> int:
        return self._num_samples

    def generate_ref_token(self):
        pass

    def get(self, idx: int):
        """Load the idx-th raw pickle and tokenize it."""
        with open(self.raw_paths[idx], 'rb') as handle:
            data = pickle.load(handle)
        data = self.token_processor.preprocess(data)
        return data
================================================
FILE: smart/layers/__init__.py
================================================
from smart.layers.attention_layer import AttentionLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from smart.layers.mlp_layer import MLPLayer
================================================
FILE: smart/layers/attention_layer.py
================================================
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
from smart.utils import weight_init
class AttentionLayer(MessagePassing):
    """Pre-norm multi-head attention over graph edges with a gated update.

    Implemented on top of PyG MessagePassing with 'add' aggregation:
    `message` computes per-edge softmax attention, `update` applies a
    sigmoid gate that blends the aggregated messages with the destination
    features, followed by a feed-forward block. Supports optional relative
    positional embeddings `r` and bipartite (src/dst) inputs.
    """

    def __init__(self,
                 hidden_dim: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 bipartite: bool,
                 has_pos_emb: bool,
                 **kwargs) -> None:
        super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.has_pos_emb = has_pos_emb
        # Standard scaled dot-product factor 1/sqrt(head_dim).
        self.scale = head_dim ** -0.5
        self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
        self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
        if has_pos_emb:
            # Projections that inject the positional embedding into K and V.
            self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
            self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
        # Gate components used in `update`.
        self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
        self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
        self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.ff_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        if bipartite:
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
        else:
            # Non-bipartite graphs share a single prenorm for src and dst.
            self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
            self.attn_prenorm_x_dst = self.attn_prenorm_x_src
        if has_pos_emb:
            self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
        self.attn_postnorm = nn.LayerNorm(hidden_dim)
        self.ff_prenorm = nn.LayerNorm(hidden_dim)
        self.ff_postnorm = nn.LayerNorm(hidden_dim)
        self.apply(weight_init)

    def forward(self,
                x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                r: Optional[torch.Tensor],
                edge_index: torch.Tensor) -> torch.Tensor:
        """Run attention + feed-forward with residual connections.

        `x` is either a single node-feature tensor or a (src, dst) pair
        for bipartite graphs; `r` is an optional per-edge embedding.
        Returns updated destination-node features.
        """
        if isinstance(x, torch.Tensor):
            x_src = x_dst = self.attn_prenorm_x_src(x)
        else:
            x_src, x_dst = x
            x_src = self.attn_prenorm_x_src(x_src)
            x_dst = self.attn_prenorm_x_dst(x_dst)
            # Residual stream is the (un-normalized) destination features.
            x = x[1]
        if self.has_pos_emb and r is not None:
            r = self.attn_prenorm_r(r)
        x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index))
        x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x)))
        return x

    def message(self,
                q_i: torch.Tensor,
                k_j: torch.Tensor,
                v_j: torch.Tensor,
                r: Optional[torch.Tensor],
                index: torch.Tensor,
                ptr: Optional[torch.Tensor]) -> torch.Tensor:
        """Per-edge attention: softmax over incoming edges of each dst node."""
        if self.has_pos_emb and r is not None:
            k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
            v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
        sim = (q_i * k_j).sum(dim=-1) * self.scale
        attn = softmax(sim, index, ptr)
        # Stashed for inspection/visualization; detached from the graph.
        self.attention_weight = attn.sum(-1).detach()
        attn = self.attn_drop(attn)
        return v_j * attn.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x_dst: torch.Tensor) -> torch.Tensor:
        """Gated blend of aggregated messages with destination features."""
        inputs = inputs.view(-1, self.num_heads * self.head_dim)
        g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
        return inputs + g * (self.to_s(x_dst) - inputs)

    def _attn_block(self,
                    x_src: torch.Tensor,
                    x_dst: torch.Tensor,
                    r: Optional[torch.Tensor],
                    edge_index: torch.Tensor) -> torch.Tensor:
        # Project, propagate over edges, and project back to hidden_dim.
        q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
        k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
        v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
        agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
        return self.to_out(agg)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff_mlp(x)
================================================
FILE: smart/layers/fourier_embedding.py
================================================
import math
from typing import List, Optional
import torch
import torch.nn as nn
from smart.utils import weight_init
class FourierEmbedding(nn.Module):
    """Embeds continuous features with learnable sinusoidal frequencies.

    Each of the `input_dim` scalar features is expanded into
    [cos(2*pi*f*x), sin(2*pi*f*x), x] over `num_freq_bands` learnable
    frequencies, passed through its own small MLP, and the per-feature
    embeddings are summed. Optional categorical embeddings are added on
    top before the output head.
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 num_freq_bands: int) -> None:
        super(FourierEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # One row of learnable frequencies per input feature; absent when
        # there are no continuous inputs.
        self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
        self.mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim),
            )
            for _ in range(input_dim)
        ])
        self.to_out = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Embed continuous and/or categorical inputs; at least one required."""
        if continuous_inputs is None:
            if categorical_embs is None:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
            x = torch.stack(categorical_embs).sum(dim=0)
        else:
            phases = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi
            # Warning: if your data are noisy, don't use learnable sinusoidal embedding
            feats = torch.cat([phases.cos(), phases.sin(), continuous_inputs.unsqueeze(-1)], dim=-1)
            x = torch.stack([mlp(feats[:, i]) for i, mlp in enumerate(self.mlps)]).sum(dim=0)
            if categorical_embs is not None:
                x = x + torch.stack(categorical_embs).sum(dim=0)
        return self.to_out(x)
class MLPEmbedding(nn.Module):
    """Plain-MLP alternative to FourierEmbedding.

    Projects the continuous input vector through a fixed 128-wide MLP;
    categorical embeddings, if given, are summed and added (or returned
    alone when there is no continuous input).
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int) -> None:
        super(MLPEmbedding, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LayerNorm(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim))
        self.apply(weight_init)

    def forward(self,
                continuous_inputs: Optional[torch.Tensor] = None,
                categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Embed continuous and/or categorical inputs; at least one required."""
        if continuous_inputs is None:
            if categorical_embs is None:
                raise ValueError('Both continuous_inputs and categorical_embs are None')
            return torch.stack(categorical_embs).sum(dim=0)
        out = self.mlp(continuous_inputs)
        if categorical_embs is not None:
            out = out + torch.stack(categorical_embs).sum(dim=0)
        return out
================================================
FILE: smart/layers/mlp_layer.py
================================================
import torch
import torch.nn as nn
from smart.utils import weight_init
class MLPLayer(nn.Module):
    """Two-layer MLP head: Linear -> LayerNorm -> ReLU -> Linear."""

    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 output_dim: int) -> None:
        super(MLPLayer, self).__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*layers)
        self.apply(weight_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of `x`."""
        return self.mlp(x)
================================================
FILE: smart/metrics/__init__.py
================================================
from smart.metrics.average_meter import AverageMeter
from smart.metrics.min_ade import minADE
from smart.metrics.min_fde import minFDE
from smart.metrics.next_token_cls import TokenCls
================================================
FILE: smart/metrics/average_meter.py
================================================
import torch
from torchmetrics import Metric
class AverageMeter(Metric):
    """Distributed-safe running mean of all values ever passed to `update`."""

    def __init__(self, **kwargs) -> None:
        super(AverageMeter, self).__init__(**kwargs)
        # Both states are summed across processes on sync.
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, val: torch.Tensor) -> None:
        """Accumulate every element of `val` into the running total."""
        batch_total = val.sum()
        batch_size = val.numel()
        self.sum += batch_total
        self.count += batch_size

    def compute(self) -> torch.Tensor:
        """Return the mean over all accumulated elements."""
        return self.sum / self.count
================================================
FILE: smart/metrics/min_ade.py
================================================
from typing import Optional
import torch
from torchmetrics import Metric
from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter
class minMultiADE(Metric):
    """Minimum (over up to `max_guesses` modes) Average Displacement Error.

    With `min_criterion='FDE'` the mode is chosen by final-step error and
    its full-trajectory ADE is accumulated; with 'ADE' the mode with the
    smallest masked ADE is accumulated directly.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiADE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'FDE') -> None:
        """Accumulate the per-agent minimum ADE for one batch.

        Raises ValueError if `min_criterion` is neither 'FDE' nor 'ADE'.
        """
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        if min_criterion == 'FDE':
            # Index of the last valid timestep per agent (argmax of
            # position-weighted mask), then the mode closest at that step.
            inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
            inds_best = torch.norm(
                pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
            # Masked mean displacement of the chosen mode.
            self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
                          valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
        elif min_criterion == 'ADE':
            # Best mode by masked ADE directly.
            self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
                          valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
        else:
            raise ValueError('{} is not a valid criterion'.format(min_criterion))
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        """Mean minimum ADE over all accumulated agents."""
        return self.sum / self.count
class minADE(Metric):
    """Single-mode Average Displacement Error over the first `eval_timestep` steps.

    Despite the class name, this variant compares `pred` against `target`
    directly (no top-k mode selection); the multi-modal logic lives in
    minMultiADE. The previously commented-out multi-modal code has been
    removed.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minADE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        # Kept for interface compatibility; unused in this single-mode variant.
        self.max_guesses = max_guesses
        # Number of leading timesteps to evaluate.
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'ADE') -> None:
        """Accumulate masked displacement error for one batch.

        `prob`, `keep_invalid_final_step` and `min_criterion` are accepted
        for signature compatibility but unused.
        """
        eval_timestep = min(self.eval_timestep, pred.shape[1])
        # NOTE(review): the per-agent sum of errors is divided by
        # pred.shape[1] (full horizon), not by eval_timestep or the number
        # of valid steps — confirm this normalization is intended.
        self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1)
                      * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
        # Count agents with at least one valid step in the window.
        self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()

    def compute(self) -> torch.Tensor:
        """Mean ADE over all accumulated agents."""
        return self.sum / self.count
================================================
FILE: smart/metrics/min_fde.py
================================================
from typing import Optional
import torch
from torchmetrics import Metric
from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter
class minMultiFDE(Metric):
    """Minimum (over up to `max_guesses` modes) Final Displacement Error."""

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        """Accumulate the per-agent minimum final-step error for one batch."""
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        # Index of the last valid timestep per agent.
        inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
        # Error at that step for every mode; keep the smallest.
        self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                               target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
                               p=2, dim=-1).min(dim=-1)[0].sum()
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        """Mean minimum FDE over all accumulated agents."""
        return self.sum / self.count
class minFDE(Metric):
    """Single-mode Final Displacement Error near step `eval_timestep`.

    Despite the name, compares `pred` against `target` directly; the
    multi-modal variant is minMultiFDE.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        # Kept for interface compatibility; unused in this single-mode variant.
        self.max_guesses = max_guesses
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        """Accumulate the error at the evaluation endpoint for one batch.

        `prob` and `keep_invalid_final_step` are accepted for signature
        compatibility but unused.
        """
        # NOTE(review): two successive -1 offsets mean the step actually
        # evaluated is min(70, T) - 2 — confirm this is the intended endpoint.
        eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
        self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
                      valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
        # Only agents valid at the endpoint contribute to the mean.
        self.count += valid_mask[:, eval_timestep-1].sum()

    def compute(self) -> torch.Tensor:
        """Mean FDE over all accumulated (endpoint-valid) agents."""
        return self.sum / self.count
================================================
FILE: smart/metrics/next_token_cls.py
================================================
from typing import Optional
import torch
from torchmetrics import Metric
from smart.metrics.utils import topk
from smart.metrics.utils import valid_filter
class TokenCls(Metric):
    """Top-k accuracy of next-token classification.

    A prediction counts as correct when the true token index appears among
    the first `max_guesses` predicted candidates; only positions marked by
    `valid_mask` contribute.
    """

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(TokenCls, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:
        """Accumulate top-k hits and valid-position counts for one batch."""
        # Does the true index occur among the first max_guesses candidates?
        hits = (pred[:, :self.max_guesses] == target[..., None]).any(dim=1)
        correct = hits * valid_mask
        self.sum += correct.sum()
        self.count += valid_mask.sum()

    def compute(self) -> torch.Tensor:
        """Return accuracy = hits / valid positions."""
        return self.sum / self.count
================================================
FILE: smart/metrics/utils.py
================================================
from typing import Optional, Tuple
import torch
from torch_scatter import gather_csr
from torch_scatter import segment_csr
def topk(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Keep the ``max_guesses`` most probable modes per sample.

    Args:
        max_guesses: number of modes to keep (clamped to the available modes).
        pred: (num_samples, num_modes, ...) per-mode predictions.
        prob: optional (num_samples, num_modes) mode probabilities; when absent
            the first modes are kept with uniform probability.
        joint: rank modes by probability averaged over samples (or over each
            CSR segment when ``ptr`` is given) instead of per sample.
        ptr: optional CSR pointer grouping samples into scenes (joint mode only).

    Returns:
        (pred_topk, prob_topk) with ``prob_topk`` renormalized to sum to 1.
    """
    keep = min(max_guesses, pred.size(1))
    num_samples = pred.size(0)

    # Nothing to drop: return everything with (re)normalized probabilities.
    if keep == pred.size(1):
        if prob is None:
            prob = pred.new_ones((num_samples, keep)) / keep
        else:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        return pred, prob

    # No probabilities: truncate to the leading modes, uniform probability.
    if prob is None:
        return pred[:, :keep], pred.new_ones((num_samples, keep)) / keep

    if joint:
        normalized = prob / prob.sum(dim=-1, keepdim=True)
        if ptr is None:
            best = torch.topk(normalized.mean(dim=0, keepdim=True),
                              k=keep, dim=-1, largest=True, sorted=True)[1]
            best = best.repeat(num_samples, 1)
        else:
            best = torch.topk(segment_csr(src=normalized, indptr=ptr, reduce='mean'),
                              k=keep, dim=-1, largest=True, sorted=True)[1]
            best = gather_csr(src=best, indptr=ptr)
    else:
        best = torch.topk(prob, k=keep, dim=-1, largest=True, sorted=True)[1]

    rows = torch.arange(num_samples).unsqueeze(-1).expand(-1, keep)
    pred_topk = pred[rows, best]
    prob_topk = prob[rows, best]
    return pred_topk, prob_topk / prob_topk.sum(dim=-1, keepdim=True)
def topkind(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Like :func:`topk`, but additionally returns the kept mode indices.

    Args:
        max_guesses: number of modes to keep (clamped to the available modes).
        pred: (num_samples, num_modes, ...) per-mode predictions.
        prob: optional (num_samples, num_modes) mode probabilities.
        joint: rank by probability averaged over samples (or per CSR segment).
        ptr: optional CSR pointer grouping samples into scenes (joint mode only).

    Returns:
        (pred_topk, prob_topk, inds_topk); ``inds_topk`` is None when every
        mode is kept (no reordering happened).
    """
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        # All modes kept: there is no meaningful selection to report.
        return pred, prob, None
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
        else:
            # BUG FIX: `inds_topk` used to be left unbound on this path, so the
            # final return raised NameError. Truncation keeps the leading modes,
            # so the kept indices are simply 0..max_guesses-1 for every row.
            inds_topk = torch.arange(max_guesses, device=pred.device).unsqueeze(0).repeat(pred.size(0), 1)
            pred_topk = pred[:, :max_guesses]
            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred_topk, prob_topk, inds_topk
def valid_filter(
        pred: torch.Tensor,
        target: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        valid_mask: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                                                       torch.Tensor, torch.Tensor]:
    """Drop agents without (enough) valid ground truth and rebuild the CSR pointer.

    Args:
        pred / target: (num_agents, num_steps, ...) trajectories.
        prob: optional per-agent tensor filtered alongside the rest.
        valid_mask: (num_agents, num_steps) validity; all-valid when omitted.
        ptr: optional CSR pointer over agents; recomputed for the kept rows.
        keep_invalid_final_step: keep agents valid at ANY step (True) or only
            those valid at the final step (False).

    Returns:
        (pred, target, prob, valid_mask, ptr) restricted to the kept agents.
    """
    if valid_mask is None:
        valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
    keep = valid_mask.any(dim=-1) if keep_invalid_final_step else valid_mask[:, -1]
    pred, target, valid_mask = pred[keep], target[keep], valid_mask[keep]
    if prob is not None:
        prob = prob[keep]
    if ptr is not None:
        # Per-scene kept counts -> exclusive prefix sum = new CSR pointer.
        counts = segment_csr(src=keep.long(), indptr=ptr, reduce='sum')
        ptr = counts.new_zeros((counts.size(0) + 1,))
        torch.cumsum(counts, dim=0, out=ptr[1:])
    else:
        ptr = target.new_tensor([0, target.size(0)])
    return pred, target, prob, valid_mask, ptr
def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
    """Score-free NMS over trajectory endpoints.

    Each mode is scored by the fraction of modes whose endpoint lies within
    ``dist_thresh`` of its own (denser endpoints score higher); modes are then
    picked greedily while suppressing endpoints the pick already covers.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, num_feat_dim)
        dist_thresh (float)
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, num_feat_dim)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes): indices into the INPUT mode order.
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    # Density score from pairwise endpoint (x, y) distances.
    endpoints = pred_trajs[:, :, -1, :]
    endpoint_dist = (endpoints[:, :, None, 0:2] - endpoints[:, None, :, 0:2]).norm(dim=-1)
    scores = (endpoint_dist < dist_thresh).sum(dim=-1) / num_modes

    # Sort modes by descending density.
    order = scores.argsort(dim=-1, descending=True)
    rows = torch.arange(batch_size).type_as(order)[:, None].repeat(1, num_modes)
    scores_sorted = scores[rows, order]
    trajs_sorted = pred_trajs[rows, order]

    goals_sorted = trajs_sorted[:, :, -1, :]
    goal_dist = (goals_sorted[:, :, None, 0:2] - goals_sorted[:, None, :, 0:2]).norm(dim=-1)
    cover = goal_dist < dist_thresh

    remaining = scores_sorted.clone()
    chosen_penalty = torch.zeros_like(remaining)
    ret_idxs = order.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = trajs_sorted.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = trajs_sorted.new_zeros(batch_size, num_ret_modes)
    batch_arange = torch.arange(batch_size).type_as(ret_idxs)
    for k in range(num_ret_modes):
        pick = remaining.argmax(dim=-1)
        ret_idxs[:, k] = pick
        # Suppress every mode whose endpoint the pick covers ...
        remaining = remaining * (~cover[batch_arange, pick]).float()
        # ... and permanently rule out the pick itself.
        chosen_penalty[batch_arange, pick] = -1
        remaining = remaining + chosen_penalty
        ret_trajs[:, k] = trajs_sorted[batch_arange, pick]
        ret_scores[:, k] = scores_sorted[batch_arange, pick]
    # Map the picks from sorted order back to the original mode order.
    rows_ret = torch.arange(batch_size).type_as(order)[:, None].repeat(1, num_ret_modes)
    ret_idxs = order[rows_ret, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
def batch_nms(pred_trajs, pred_scores,
              dist_thresh, num_ret_modes=6,
              mode='static', speed=None):
    """Greedy endpoint NMS using caller-provided mode scores.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, num_feat_dim)
        pred_scores (batch_size, num_modes)
        dist_thresh (float): suppression radius in 'static' mode.
        num_ret_modes (int, optional): Defaults to 6.
        mode: 'static' (isotropic radius) or 'speed' (separate lon/lat radii).
        speed: NOTE(review): currently unused — 'speed' mode applies fixed
            4 m longitudinal / 0.5 m lateral thresholds; confirm intent.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, num_feat_dim)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes): indices into the INPUT mode order.
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    # Sort modes by descending score.
    order = pred_scores.argsort(dim=-1, descending=True)
    rows = torch.arange(batch_size).type_as(order)[:, None].repeat(1, num_modes)
    scores_sorted = pred_scores[rows, order]
    trajs_sorted = pred_trajs[rows, order]
    goals_sorted = trajs_sorted[:, :, -1, :]

    if mode == "speed":
        scale = torch.ones(batch_size).to(goals_sorted.device)
        lon_thresh = 4 * scale
        lat_thresh = 0.5 * scale
        lon_dist = (goals_sorted[:, :, None, [0]] - goals_sorted[:, None, :, [0]]).norm(dim=-1)
        lat_dist = (goals_sorted[:, :, None, [1]] - goals_sorted[:, None, :, [1]]).norm(dim=-1)
        cover = (lon_dist < lon_thresh[:, None, None]) & (lat_dist < lat_thresh[:, None, None])
    else:
        goal_dist = (goals_sorted[:, :, None, 0:2] - goals_sorted[:, None, :, 0:2]).norm(dim=-1)
        cover = goal_dist < dist_thresh

    remaining = scores_sorted.clone()
    chosen_penalty = torch.zeros_like(remaining)
    ret_idxs = order.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = trajs_sorted.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = trajs_sorted.new_zeros(batch_size, num_ret_modes)
    batch_arange = torch.arange(batch_size).type_as(ret_idxs)
    for k in range(num_ret_modes):
        pick = remaining.argmax(dim=-1)
        ret_idxs[:, k] = pick
        # Suppress covered modes, then permanently rule out the pick itself.
        remaining = remaining * (~cover[batch_arange, pick]).float()
        chosen_penalty[batch_arange, pick] = -1
        remaining = remaining + chosen_penalty
        ret_trajs[:, k] = trajs_sorted[batch_arange, pick]
        ret_scores[:, k] = scores_sorted[batch_arange, pick]
    # Map the picks from sorted order back to the original mode order.
    rows_ret = torch.arange(batch_size).type_as(order)[:, None].repeat(1, num_ret_modes)
    ret_idxs = order[rows_ret, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
def batch_nms_token(pred_trajs, pred_scores,
                    dist_thresh, num_ret_modes=6,
                    mode='static', speed=None):
    """Greedy NMS over token end states (no time dimension).

    Args:
        pred_trajs (batch_size, num_modes, num_feat_dim): per-mode end states;
            channels 0:2 are the (x, y) position used for suppression.
        pred_scores (batch_size, num_modes)
        dist_thresh (float): suppression radius in 'static' mode.
        num_ret_modes (int, optional): Defaults to 6.
        mode: 'static' (fixed radius) or 'nearby' (each mode suppresses its 5
            nearest neighbours, self included).
        speed: unused; kept for signature parity with :func:`batch_nms`.

    Returns:
        ret_goals (batch_size, num_ret_modes, num_feat_dim)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes): indices into the INPUT mode order.
    """
    batch_size, num_modes, num_feat_dim = pred_trajs.shape

    # Sort modes by descending score.
    order = pred_scores.argsort(dim=-1, descending=True)
    rows = torch.arange(batch_size).type_as(order)[:, None].repeat(1, num_modes)
    scores_sorted = pred_scores[rows, order]
    goals_sorted = pred_trajs[rows, order]

    pairwise = (goals_sorted[:, :, None, 0:2] - goals_sorted[:, None, :, 0:2]).norm(dim=-1)
    if mode == "nearby":
        # Adaptive radius: distance to the 5th-nearest endpoint (self included).
        nearest, _ = torch.topk(pairwise, 5, dim=-1, largest=False)
        cover = pairwise < nearest[..., -1][..., None]
    else:
        cover = pairwise < dist_thresh

    remaining = scores_sorted.clone()
    chosen_penalty = torch.zeros_like(remaining)
    ret_idxs = order.new_zeros(batch_size, num_ret_modes).long()
    ret_goals = goals_sorted.new_zeros(batch_size, num_ret_modes, num_feat_dim)
    ret_scores = goals_sorted.new_zeros(batch_size, num_ret_modes)
    batch_arange = torch.arange(batch_size).type_as(ret_idxs)
    for k in range(num_ret_modes):
        pick = remaining.argmax(dim=-1)
        ret_idxs[:, k] = pick
        # Suppress covered modes, then permanently rule out the pick itself.
        remaining = remaining * (~cover[batch_arange, pick]).float()
        chosen_penalty[batch_arange, pick] = -1
        remaining = remaining + chosen_penalty
        ret_goals[:, k] = goals_sorted[batch_arange, pick]
        ret_scores[:, k] = scores_sorted[batch_arange, pick]
    # Map the picks from sorted order back to the original mode order.
    rows_ret = torch.arange(batch_size).type_as(order)[:, None].repeat(1, num_ret_modes)
    ret_idxs = order[rows_ret, ret_idxs]
    return ret_goals, ret_scores, ret_idxs
================================================
FILE: smart/model/__init__.py
================================================
from smart.model.smart import SMART
================================================
FILE: smart/model/smart.py
================================================
import contextlib
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from smart.metrics import minADE
from smart.metrics import minFDE
from smart.metrics import TokenCls
from smart.modules import SMARTDecoder
from torch.optim.lr_scheduler import LambdaLR
import math
import numpy as np
import pickle
from collections import defaultdict
import os
from waymo_open_dataset.protos import sim_agents_submission_pb2
def cal_polygon_contour(x, y, theta, width, length):
    """Return the four corners of an oriented rectangle (agent bounding box).

    The box is centred at (x, y), rotated by ``theta`` (radians), with
    ``length`` along the heading axis and ``width`` across it.

    Returns:
        List of four (x, y) tuples in order
        [left_front, right_front, right_back, left_back].
    """
    # Hoist the trigonometry: the original recomputed cos/sin eight times.
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    half_l = 0.5 * length
    half_w = 0.5 * width
    # Offset along the heading axis and toward the left side, respectively.
    dx_l, dy_l = half_l * cos_t, half_l * sin_t
    dx_w, dy_w = -half_w * sin_t, half_w * cos_t
    left_front = (x + dx_l + dx_w, y + dy_l + dy_w)
    right_front = (x + dx_l - dx_w, y + dy_l - dy_w)
    right_back = (x - dx_l - dx_w, y - dy_l - dy_w)
    left_back = (x - dx_l + dx_w, y - dy_l + dy_w)
    polygon_contour = [left_front, right_front, right_back, left_back]
    return polygon_contour
def joint_scene_from_states(states, object_ids) -> sim_agents_submission_pb2.JointScene:
    """Pack per-object simulated states into a WOSAC ``JointScene`` proto.

    Args:
        states: tensor (num_objects, num_steps, >=4); channels 0..3 are
            x, y, z and heading.
        object_ids: per-object ids; ``.item()`` is called on each entry.
    """
    states = states.numpy()
    simulated_trajectories = [
        sim_agents_submission_pb2.SimulatedTrajectory(
            center_x=states[idx, :, 0], center_y=states[idx, :, 1],
            center_z=states[idx, :, 2], heading=states[idx, :, 3],
            object_id=object_ids[idx].item()
        )
        for idx in range(len(object_ids))
    ]
    return sim_agents_submission_pb2.JointScene(simulated_trajectories=simulated_trajectories)
class SMART(pl.LightningModule):
    """LightningModule wrapping the token-based SMART motion-generation model.

    Discretizes map polylines and agent motion into tokens (vocabularies are
    loaded from the pickles under ``smart/tokens/``) and trains a
    ``SMARTDecoder`` to predict each agent's next motion token with a
    label-smoothed cross-entropy loss.
    """

    def __init__(self, model_config) -> None:
        """Build the decoder, token vocabularies and metrics from the config.

        Args:
            model_config: namespace-like config; top-level optimizer/embedding
                settings plus a nested ``decoder`` sub-config are read.
        """
        super(SMART, self).__init__()
        self.save_hyperparameters()
        self.model_config = model_config
        # Optimizer / LR-schedule settings (consumed in configure_optimizers).
        self.warmup_steps = model_config.warmup_steps
        self.lr = model_config.lr
        self.total_steps = model_config.total_steps
        self.dataset = model_config.dataset
        self.input_dim = model_config.input_dim
        self.hidden_dim = model_config.hidden_dim
        self.output_dim = model_config.output_dim
        self.output_head = model_config.output_head
        self.num_historical_steps = model_config.num_historical_steps
        self.num_future_steps = model_config.decoder.num_future_steps
        self.num_freq_bands = model_config.num_freq_bands
        self.vis_map = False
        # When True, match_token_map samples among the 8 nearest map tokens
        # instead of taking the argmin (token-matching augmentation).
        self.noise = True
        # Token vocabularies ship with the package, relative to this module.
        module_dir = os.path.dirname(os.path.dirname(__file__))
        self.map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        self.init_map_token()
        self.token_path = os.path.join(module_dir, 'tokens/cluster_frame_5_2048.pkl')
        token_data = self.get_trajectory_token()
        self.encoder = SMARTDecoder(
            dataset=model_config.dataset,
            input_dim=model_config.input_dim,
            hidden_dim=model_config.hidden_dim,
            num_historical_steps=model_config.num_historical_steps,
            num_freq_bands=model_config.num_freq_bands,
            num_heads=model_config.num_heads,
            head_dim=model_config.head_dim,
            dropout=model_config.dropout,
            num_map_layers=model_config.decoder.num_map_layers,
            num_agent_layers=model_config.decoder.num_agent_layers,
            pl2pl_radius=model_config.decoder.pl2pl_radius,
            pl2a_radius=model_config.decoder.pl2a_radius,
            a2a_radius=model_config.decoder.a2a_radius,
            time_span=model_config.decoder.time_span,
            map_token={'traj_src': self.map_token['traj_src']},
            token_data=token_data,
            token_size=model_config.decoder.token_size
        )
        self.minADE = minADE(max_guesses=1)
        self.minFDE = minFDE(max_guesses=1)
        self.TokenCls = TokenCls(max_guesses=1)
        self.test_predictions = dict()
        self.cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.map_cls_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        # When True, validation also rolls the model out autoregressively
        # and reports ADE/FDE (see validation_step).
        self.inference_token = False
        self.rollout_num = 1

    def get_trajectory_token(self):
        """Load the agent-motion token vocabulary; cache its parts on ``self``."""
        token_data = pickle.load(open(self.token_path, 'rb'))
        self.trajectory_token = token_data['token']
        self.trajectory_token_traj = token_data['traj']
        self.trajectory_token_all = token_data['token_all']
        return token_data

    def init_map_token(self):
        """Load map-token source polylines and precompute matching features.

        Populates ``self.map_token`` with the raw token polylines
        (``traj_src``), a subsampled version used for nearest-token matching
        (``sample_pt``) and each token's end heading (``traj_end_theta``).
        """
        self.argmin_sample_len = 3
        map_token_traj = pickle.load(open(self.map_token_traj_path, 'rb'))
        self.map_token = {'traj_src': map_token_traj['traj_src'], }
        # End heading of each token polyline from its last segment.
        traj_end_theta = np.arctan2(self.map_token['traj_src'][:, -1, 1]-self.map_token['traj_src'][:, -2, 1],
                                    self.map_token['traj_src'][:, -1, 0]-self.map_token['traj_src'][:, -2, 0])
        indices = torch.linspace(0, self.map_token['traj_src'].shape[1]-1, steps=self.argmin_sample_len).long()
        self.map_token['sample_pt'] = torch.from_numpy(self.map_token['traj_src'][:, indices]).to(torch.float)
        self.map_token['traj_end_theta'] = torch.from_numpy(traj_end_theta).to(torch.float)
        self.map_token['traj_src'] = torch.from_numpy(self.map_token['traj_src']).to(torch.float)

    def forward(self, data: HeteroData):
        """Teacher-forced forward pass; returns the decoder's prediction dict."""
        res = self.encoder(data)
        return res

    def inference(self, data: HeteroData):
        """Autoregressive rollout; returns the decoder's inference dict."""
        res = self.encoder.inference(data)
        return res

    def maybe_autocast(self, dtype=torch.float16):
        """Return a CUDA autocast context on GPU, or a no-op context on CPU."""
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    def training_step(self,
                      data,
                      batch_idx):
        """Next-token cross-entropy on tokenized agents; returns the loss."""
        data = self.match_token_map(data)
        data = self.sample_pt_pred(data)
        if isinstance(data, Batch):
            # Shift per-graph AV indices so they are global in the batched graph.
            data['agent']['av_index'] += data['agent']['ptr'][:-1]
        pred = self(data)
        next_token_prob = pred['next_token_prob']
        next_token_idx_gt = pred['next_token_idx_gt']
        next_token_eval_mask = pred['next_token_eval_mask']
        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
        loss = cls_loss
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        self.log('cls_loss', cls_loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        return loss

    def validation_step(self,
                        data,
                        batch_idx):
        """Log token-classification loss/accuracy; optionally roll out and score ADE/FDE."""
        data = self.match_token_map(data)
        data = self.sample_pt_pred(data)
        if isinstance(data, Batch):
            data['agent']['av_index'] += data['agent']['ptr'][:-1]
        pred = self(data)
        next_token_idx = pred['next_token_idx']
        next_token_idx_gt = pred['next_token_idx_gt']
        next_token_eval_mask = pred['next_token_eval_mask']
        next_token_prob = pred['next_token_prob']
        cls_loss = self.cls_loss(next_token_prob[next_token_eval_mask], next_token_idx_gt[next_token_eval_mask])
        loss = cls_loss
        self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
                             valid_mask=next_token_eval_mask[next_token_eval_mask])
        self.log('val_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1] # * (data['agent']['category'] == 3)
        if self.inference_token:
            # Closed-loop rollout metrics, only when inference_token is enabled.
            pred = self.inference(data)
            pos_a = pred['pos_a']
            gt = pred['gt']
            valid_mask = data['agent']['valid_mask'][:, self.num_historical_steps:]
            pred_traj = pred['pred_traj']
            # next_token_idx = pred['next_token_idx'][..., None]
            # next_token_idx_gt = pred['next_token_idx_gt'][:, 2:]
            # next_token_eval_mask = pred['next_token_eval_mask'][:, 2:]
            # next_token_eval_mask[:, 1:] = False
            # self.TokenCls.update(pred=next_token_idx[next_token_eval_mask], target=next_token_idx_gt[next_token_eval_mask],
            #                      valid_mask=next_token_eval_mask[next_token_eval_mask])
            # self.log('val_inference_cls_acc', self.TokenCls, prog_bar=True, on_step=False, on_epoch=True, batch_size=1, sync_dist=True)
            # Agents that are valid at the last observed step are evaluated.
            eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps-1]
            self.minADE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
            self.minFDE.update(pred=pred_traj[eval_mask], target=gt[eval_mask], valid_mask=valid_mask[eval_mask])
            # print('ade: ', self.minADE.compute(), 'fde: ', self.minFDE.compute())
            self.log('val_minADE', self.minADE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)
            self.log('val_minFDE', self.minFDE, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)

    def on_validation_start(self):
        """Reset the per-epoch accumulators used during validation."""
        self.gt = []
        self.pred = []
        self.scenario_rollouts = []
        self.batch_metric = defaultdict(list)

    def configure_optimizers(self):
        """AdamW with linear warmup followed by cosine decay to zero."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)

        def lr_lambda(current_step):
            # Linear ramp over warmup_steps, then cosine from 1 down to 0.
            if current_step + 1 < self.warmup_steps:
                return float(current_step + 1) / float(max(1, self.warmup_steps))
            return max(
                0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - self.warmup_steps) / float(max(1, self.total_steps - self.warmup_steps))))
            )

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
        return [optimizer], [lr_scheduler]

    def load_params_from_file(self, filename, logger, to_cpu=False):
        """Load a checkpoint's state_dict, skipping missing / shape-mismatched keys.

        Args:
            filename: checkpoint path.
            logger: logger used for progress and summary messages.
            to_cpu: map tensors to CPU while loading.

        Returns:
            (it, epoch) read from the checkpoint (0.0 / -1 when absent).

        Raises:
            FileNotFoundError: if ``filename`` does not exist.
        """
        if not os.path.isfile(filename):
            raise FileNotFoundError
        logger.info('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
        loc_type = torch.device('cpu') if to_cpu else None
        checkpoint = torch.load(filename, map_location=loc_type)
        model_state_disk = checkpoint['state_dict']
        version = checkpoint.get("version", None)
        if version is not None:
            logger.info('==> Checkpoint trained from version: %s' % version)
        logger.info(f'The number of disk ckpt keys: {len(model_state_disk)}')
        model_state = self.state_dict()
        model_state_disk_filter = {}
        for key, val in model_state_disk.items():
            if key in model_state and model_state_disk[key].shape == model_state[key].shape:
                model_state_disk_filter[key] = val
            else:
                # Report why each dropped key was skipped (absent vs shape mismatch).
                if key not in model_state:
                    print(f'Ignore key in disk (not found in model): {key}, shape={val.shape}')
                else:
                    print(f'Ignore key in disk (shape does not match): {key}, load_shape={val.shape}, model_shape={model_state[key].shape}')
        model_state_disk = model_state_disk_filter
        missing_keys, unexpected_keys = self.load_state_dict(model_state_disk, strict=False)
        logger.info(f'Missing keys: {missing_keys}')
        logger.info(f'The number of missing keys: {len(missing_keys)}')
        logger.info(f'The number of unexpected keys: {len(unexpected_keys)}')
        logger.info('==> Done (total keys %d)' % (len(model_state)))
        epoch = checkpoint.get('epoch', -1)
        it = checkpoint.get('it', 0.0)
        return it, epoch

    def match_token_map(self, data):
        """Tokenize map polylines: assign each polyline its nearest map token.

        Rotates every polyline into its local frame, matches it against the
        subsampled token vocabulary (optionally sampling among the 8 nearest
        when ``self.noise`` is set), and writes token ids plus bookkeeping
        masks/positions back into ``data['pt_token']``.
        """
        traj_pos = data['map_save']['traj_pos'].to(torch.float)
        traj_theta = data['map_save']['traj_theta'].to(torch.float)
        pl_idx_list = data['map_save']['pl_idx_list']
        token_sample_pt = self.map_token['sample_pt'].to(traj_pos.device)
        token_src = self.map_token['traj_src'].to(traj_pos.device)
        max_traj_len = self.map_token['traj_src'].shape[1]
        pl_num = traj_pos.shape[0]

        pt_token_pos = traj_pos[:, 0, :].clone()
        pt_token_orientation = traj_theta.clone()
        # World -> local rotation for each polyline.
        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[..., 0, 0] = cos
        rot_mat[..., 0, 1] = -sin
        rot_mat[..., 1, 0] = sin
        rot_mat[..., 1, 1] = cos
        traj_pos_local = torch.bmm((traj_pos - traj_pos[:, 0:1]), rot_mat.view(-1, 2, 2))
        # Nearest token by summed squared distance over the sampled points.
        distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1))
        pt_token_id = torch.argmin(distance, dim=1)
        if self.noise:
            # Augmentation: sample uniformly among the 8 closest tokens.
            topk_indices = torch.argsort(torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1))**2, dim=(-2, -1)), dim=1)[:, :8]
            sample_topk = torch.randint(0, topk_indices.shape[-1], size=(topk_indices.shape[0], 1), device=topk_indices.device)
            pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)

        # Local -> world rotation (transpose of the matrix above).
        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[..., 0, 0] = cos
        rot_mat[..., 0, 1] = sin
        rot_mat[..., 1, 0] = -sin
        rot_mat[..., 1, 1] = cos
        token_src_world = torch.bmm(token_src[None, ...].repeat(pl_num, 1, 1, 1).reshape(pl_num, -1, 2),
                                    rot_mat.view(-1, 2, 2)).reshape(pl_num, token_src.shape[0], max_traj_len, 2) + traj_pos[:, None, [0], :]
        # NOTE(review): the view below hard-codes a 1024-token x 11-point map
        # vocabulary, and `token_src_world_select` is not used afterwards here.
        token_src_world_select = token_src_world.view(-1, 1024, 11, 2)[torch.arange(pt_token_id.view(-1).shape[0]), pt_token_id.view(-1)].view(pl_num, max_traj_len, 2)

        pl_idx_full = pl_idx_list.clone()
        token2pl = torch.stack([torch.arange(len(pl_idx_list), device=traj_pos.device), pl_idx_full.long()])
        # Count tokens per polygon, split by side (0: left, 1: right, 2: center).
        count_nums = []
        for pl in pl_idx_full.unique():
            pt = token2pl[0, token2pl[1, :] == pl]
            left_side = (data['pt_token']['side'][pt] == 0).sum()
            right_side = (data['pt_token']['side'][pt] == 1).sum()
            center_side = (data['pt_token']['side'][pt] == 2).sum()
            count_nums.append(torch.Tensor([left_side, right_side, center_side]))
        count_nums = torch.stack(count_nums, dim=0)
        num_polyline = int(count_nums.max().item())
        # traj_mask[i, side, j] is True for the first count_nums[i, side] slots.
        traj_mask = torch.zeros((int(len(pl_idx_full.unique())), 3, num_polyline), dtype=bool)
        idx_matrix = torch.arange(traj_mask.size(2)).unsqueeze(0).unsqueeze(0)
        idx_matrix = idx_matrix.expand(traj_mask.size(0), traj_mask.size(1), -1) #
        counts_num_expanded = count_nums.unsqueeze(-1)
        mask_update = idx_matrix < counts_num_expanded
        traj_mask[mask_update] = True

        data['pt_token']['traj_mask'] = traj_mask
        # Positions gain a zero z-channel (height) alongside x, y.
        data['pt_token']['position'] = torch.cat([pt_token_pos, torch.zeros((data['pt_token']['num_nodes'], 1),
                                                                            device=traj_pos.device, dtype=torch.float)], dim=-1)
        data['pt_token']['orientation'] = pt_token_orientation
        data['pt_token']['height'] = data['pt_token']['position'][:, -1]
        data[('pt_token', 'to', 'map_polygon')] = {}
        data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
        data['pt_token']['token_idx'] = pt_token_id
        return data

    def sample_pt_pred(self, data):
        """Randomly mask ~1/3 of map-token slots for masked next-token prediction.

        Writes three flat masks over valid token slots into ``data['pt_token']``:
        ``pt_valid_mask`` (tokens kept visible), ``pt_pred_mask`` (slots whose
        successor must be predicted) and ``pt_target_mask`` (those successors).
        """
        traj_mask = data['pt_token']['traj_mask']
        # Candidate indices 1..T-1 for every (polygon, side) row.
        raw_pt_index = torch.arange(1, traj_mask.shape[2]).repeat(traj_mask.shape[0], traj_mask.shape[1], 1)
        # Randomly draw (T-1)//3 slot indices per row to mask out.
        masked_pt_index = raw_pt_index.view(-1)[torch.randperm(raw_pt_index.numel())[:traj_mask.shape[0]*traj_mask.shape[1]*((traj_mask.shape[2]-1)//3)].reshape(traj_mask.shape[0], traj_mask.shape[1], (traj_mask.shape[2]-1)//3)]
        masked_pt_index = torch.sort(masked_pt_index, -1)[0]
        pt_valid_mask = traj_mask.clone()
        pt_valid_mask.scatter_(2, masked_pt_index, False)
        pt_pred_mask = traj_mask.clone()
        pt_pred_mask.scatter_(2, masked_pt_index, False)
        # Keep only slots directly preceding a masked index ...
        tmp_mask = pt_pred_mask.clone()
        tmp_mask[:, :, :] = True
        tmp_mask.scatter_(2, masked_pt_index-1, False)
        pt_pred_mask.masked_fill_(tmp_mask, False)
        # ... whose successor slot actually exists in the padded layout.
        pt_pred_mask = pt_pred_mask * torch.roll(traj_mask, shifts=-1, dims=2)
        # Target slots sit one step after each prediction slot.
        pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)
        data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
        data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
        data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]
        return data
================================================
FILE: smart/modules/__init__.py
================================================
from smart.modules.smart_decoder import SMARTDecoder
from smart.modules.map_decoder import SMARTMapDecoder
from smart.modules.agent_decoder import SMARTAgentDecoder
================================================
FILE: smart/modules/agent_decoder.py
================================================
import pickle
from typing import Dict, Mapping, Optional
import torch
import torch.nn as nn
from smart.layers import MLPLayer
from smart.layers.attention_layer import AttentionLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from torch_cluster import radius, radius_graph
from torch_geometric.data import Batch, HeteroData
from torch_geometric.utils import dense_to_sparse, subgraph
from smart.utils import angle_between_2d_vectors, weight_init, wrap_angle
import math
def cal_polygon_contour(x, y, theta, width, length):
    """Return the four corners of an oriented rectangle (agent bounding box).

    The box is centred at (x, y), rotated by ``theta`` (radians), with
    ``length`` along the heading axis and ``width`` across it.

    Returns:
        List of four (x, y) tuples in order
        [left_front, right_front, right_back, left_back].
    """
    # Hoist the trigonometry: the original recomputed cos/sin eight times.
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    half_l = 0.5 * length
    half_w = 0.5 * width
    # Offset along the heading axis and toward the left side, respectively.
    dx_l, dy_l = half_l * cos_t, half_l * sin_t
    dx_w, dy_w = -half_w * sin_t, half_w * cos_t
    left_front = (x + dx_l + dx_w, y + dy_l + dy_w)
    right_front = (x + dx_l - dx_w, y + dy_l - dy_w)
    right_back = (x - dx_l - dx_w, y - dy_l - dy_w)
    left_back = (x - dx_l + dx_w, y - dy_l + dy_w)
    polygon_contour = [left_front, right_front, right_back, left_back]
    return polygon_contour
class SMARTAgentDecoder(nn.Module):
    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 time_span: Optional[int],
                 pl2a_radius: float,
                 a2a_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 token_data: Dict,
                 token_size=512) -> None:
        """Agent decoder: embeds tokenized agent motion and attends over time,
        map tokens and neighbouring agents to predict the next motion token.

        Args:
            dataset: dataset name (stored; not interpreted here).
            input_dim: dimensionality of positions used for motion vectors.
            hidden_dim: shared embedding width of all sub-modules.
            num_historical_steps: number of observed timesteps.
            time_span: temporal attention horizon; falls back to
                ``num_historical_steps`` when None.
            pl2a_radius / a2a_radius: interaction radii for map->agent and
                agent->agent edges.
            num_layers / num_heads / head_dim / dropout: attention stack config.
            token_data: trajectory token vocabulary with keys 'token', 'traj'
                and 'token_all'.
            token_size: size of the motion-token vocabulary (classifier output).
        """
        super(SMARTAgentDecoder, self).__init__()
        self.dataset = dataset
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.time_span = time_span if time_span is not None else num_historical_steps
        self.pl2a_radius = pl2a_radius
        self.a2a_radius = a2a_radius
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout
        # Raw feature sizes of the Fourier/MLP embeddings created below.
        input_dim_x_a = 2
        input_dim_r_t = 4
        input_dim_r_pt2a = 3
        input_dim_r_a2a = 3
        input_dim_token = 8
        self.type_a_emb = nn.Embedding(4, hidden_dim)
        self.shape_emb = MLPLayer(3, hidden_dim, hidden_dim)
        self.x_a_emb = FourierEmbedding(input_dim=input_dim_x_a, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
        self.r_t_emb = FourierEmbedding(input_dim=input_dim_r_t, hidden_dim=hidden_dim, num_freq_bands=num_freq_bands)
        self.r_pt2a_emb = FourierEmbedding(input_dim=input_dim_r_pt2a, hidden_dim=hidden_dim,
                                           num_freq_bands=num_freq_bands)
        self.r_a2a_emb = FourierEmbedding(input_dim=input_dim_r_a2a, hidden_dim=hidden_dim,
                                          num_freq_bands=num_freq_bands)
        # Separate motion-token embeddings per agent class (veh / ped / cyc).
        self.token_emb_veh = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.token_emb_ped = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.token_emb_cyc = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.fusion_emb = MLPEmbedding(input_dim=self.hidden_dim * 2, hidden_dim=self.hidden_dim)
        # Attention stacks: temporal (self), map-token->agent (bipartite),
        # agent->agent (self); one layer of each per decoder layer.
        self.t_attn_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.pt2a_attn_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=True, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.a2a_attn_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        self.token_size = token_size
        # Classifier over the motion-token vocabulary.
        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=self.token_size)
        self.trajectory_token = token_data['token']
        self.trajectory_token_traj = token_data['traj']
        self.trajectory_token_all = token_data['token_all']
        self.apply(weight_init)
        # Each motion token spans `shift` raw timesteps — presumably 5 steps
        # per token at the dataset frame rate; TODO confirm.
        self.shift = 5
        self.beam_size = 5
        self.hist_mask = True
def transform_rel(self, token_traj, prev_pos, prev_heading=None):
if prev_heading is None:
diff_xy = prev_pos[:, :, -1, :] - prev_pos[:, :, -2, :]
prev_heading = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
num_agent, num_step, traj_num, traj_dim = token_traj.shape
cos, sin = prev_heading.cos(), prev_heading.sin()
rot_mat = torch.zeros((num_agent, num_step, 2, 2), device=prev_heading.device)
rot_mat[:, :, 0, 0] = cos
rot_mat[:, :, 0, 1] = -sin
rot_mat[:, :, 1, 0] = sin
rot_mat[:, :, 1, 1] = cos
agent_diff_rel = torch.bmm(token_traj.view(-1, traj_num, 2), rot_mat.view(-1, 2, 2)).view(num_agent, num_step, traj_num, traj_dim)
agent_pred_rel = agent_diff_rel + prev_pos[:, :, -1:, :]
return agent_pred_rel
    def agent_token_embedding(self, data, agent_category, agent_token_index, pos_a, head_vector_a, inference=False):
        """Embed agents' motion tokens and fuse them with continuous motion features.

        Args:
            data: heterogeneous batch; reads data['agent'] fields
                ('num_nodes', 'type', 'shape', 'token_velocity').
            agent_category: per-agent category tensor (not read in this method).
            agent_token_index: (num_agent, num_step) motion-token index per step.
            pos_a: (num_agent, num_step, 2) tokenized agent positions.
            head_vector_a: (num_agent, num_step, 2) unit heading vectors.
            inference: when True, additionally returns the full per-type token
                trajectories and the raw token embeddings needed for rollout.

        Returns:
            (feat_a, agent_token_traj) in training mode; in inference mode also
            (agent_token_traj_all, agent_token_emb, categorical_embs).
        """
        num_agent, num_step, traj_dim = pos_a.shape
        # Per-step displacement; the first step has no predecessor, so zero-pad it.
        motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
                                     pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
        agent_type = data['agent']['type']
        veh_mask = (agent_type == 0)
        cyc_mask = (agent_type == 2)
        ped_mask = (agent_type == 1)
        # Embed each type's whole token vocabulary; cached on self because the
        # inference rollout re-reads these embeddings every recurrent step.
        trajectory_token_veh = torch.from_numpy(self.trajectory_token['veh']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_veh = self.token_emb_veh(trajectory_token_veh.view(trajectory_token_veh.shape[0], -1))
        trajectory_token_ped = torch.from_numpy(self.trajectory_token['ped']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_ped = self.token_emb_ped(trajectory_token_ped.view(trajectory_token_ped.shape[0], -1))
        trajectory_token_cyc = torch.from_numpy(self.trajectory_token['cyc']).clone().to(pos_a.device).to(torch.float)
        self.agent_token_emb_cyc = self.token_emb_cyc(trajectory_token_cyc.view(trajectory_token_cyc.shape[0], -1))
        if inference:
            # Full token trajectories (self.shift intermediate frames + final box)
            # per agent type, used to roll predicted tokens out into positions.
            agent_token_traj_all = torch.zeros((num_agent, self.token_size, self.shift + 1, 4, 2), device=pos_a.device)
            trajectory_token_all_veh = torch.from_numpy(self.trajectory_token_all['veh']).clone().to(pos_a.device).to(
                torch.float)
            trajectory_token_all_ped = torch.from_numpy(self.trajectory_token_all['ped']).clone().to(pos_a.device).to(
                torch.float)
            trajectory_token_all_cyc = torch.from_numpy(self.trajectory_token_all['cyc']).clone().to(pos_a.device).to(
                torch.float)
            agent_token_traj_all[veh_mask] = torch.cat(
                [trajectory_token_all_veh[:, :self.shift], trajectory_token_veh[:, None, ...]], dim=1)
            agent_token_traj_all[ped_mask] = torch.cat(
                [trajectory_token_all_ped[:, :self.shift], trajectory_token_ped[:, None, ...]], dim=1)
            agent_token_traj_all[cyc_mask] = torch.cat(
                [trajectory_token_all_cyc[:, :self.shift], trajectory_token_cyc[:, None, ...]], dim=1)
        # Look up each agent's per-step token embedding by its token index.
        agent_token_emb = torch.zeros((num_agent, num_step, self.hidden_dim), device=pos_a.device)
        agent_token_emb[veh_mask] = self.agent_token_emb_veh[agent_token_index[veh_mask]]
        agent_token_emb[ped_mask] = self.agent_token_emb_ped[agent_token_index[ped_mask]]
        agent_token_emb[cyc_mask] = self.agent_token_emb_cyc[agent_token_index[cyc_mask]]
        # Per-agent copy of the type-specific token vocabulary trajectories.
        agent_token_traj = torch.zeros((num_agent, num_step, self.token_size, 4, 2), device=pos_a.device)
        agent_token_traj[veh_mask] = trajectory_token_veh
        agent_token_traj[ped_mask] = trajectory_token_ped
        agent_token_traj[cyc_mask] = trajectory_token_cyc
        vel = data['agent']['token_velocity']  # NOTE(review): read but unused here
        # Categorical embeddings: agent type and its shape at the last historical step,
        # repeated once per time step to match the flattened feature layout.
        categorical_embs = [
            self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=num_step,
                                                                            dim=0),
            self.shape_emb(data['agent']['shape'][:, self.num_historical_steps - 1, :]).repeat_interleave(
                repeats=num_step,
                dim=0)
        ]
        # Continuous motion features: speed and bearing relative to heading.
        feature_a = torch.stack(
            [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
             ], dim=-1)
        x_a = self.x_a_emb(continuous_inputs=feature_a.view(-1, feature_a.size(-1)),
                           categorical_embs=categorical_embs)
        x_a = x_a.view(-1, num_step, self.hidden_dim)
        # Fuse token embedding with continuous features back to hidden_dim.
        feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
        feat_a = self.fusion_emb(feat_a)
        if inference:
            return feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs
        else:
            return feat_a, agent_token_traj
def agent_predict_next(self, data, agent_category, feat_a):
num_agent, num_step, traj_dim = data['agent']['token_pos'].shape
agent_type = data['agent']['type']
veh_mask = (agent_type == 0) # * agent_category==3
cyc_mask = (agent_type == 2) # * agent_category==3
ped_mask = (agent_type == 1) # * agent_category==3
token_res = torch.zeros((num_agent, num_step, self.token_size), device=agent_category.device)
token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])
token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])
return token_res
def agent_predict_next_inf(self, data, agent_category, feat_a):
num_agent, traj_dim = feat_a.shape
agent_type = data['agent']['type']
veh_mask = (agent_type == 0) # * agent_category==3
cyc_mask = (agent_type == 2) # * agent_category==3
ped_mask = (agent_type == 1) # * agent_category==3
token_res = torch.zeros((num_agent, self.token_size), device=agent_category.device)
token_res[veh_mask] = self.token_predict_head(feat_a[veh_mask])
token_res[cyc_mask] = self.token_predict_cyc_head(feat_a[cyc_mask])
token_res[ped_mask] = self.token_predict_walker_head(feat_a[ped_mask])
return token_res
    def build_temporal_edge(self, pos_a, head_a, head_vector_a, num_agent, mask, inference_mask=None):
        """Build per-agent temporal attention edges across token time steps.

        Edges run from earlier to later steps of the same agent, restricted to
        a window of time_span / shift steps. Returns (edge_index_t, r_t) where
        r_t is the Fourier embedding of relative position/heading/time-gap.
        """
        pos_t = pos_a.reshape(-1, self.input_dim)
        head_t = head_a.reshape(-1)
        head_vector_t = head_vector_a.reshape(-1, 2)
        hist_mask = mask.clone()
        if self.hist_mask and self.training:
            # Training-time augmentation: drop 10 random steps per agent
            # (sampled with replacement, so fewer than 10 may be masked).
            hist_mask[
                torch.arange(mask.shape[0]).unsqueeze(1), torch.randint(0, mask.shape[1], (num_agent, 10))] = False
            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
        elif inference_mask is not None:
            # Rollout: only edges whose receiver step is marked for inference.
            mask_t = hist_mask.unsqueeze(2) & inference_mask.unsqueeze(1)
        else:
            mask_t = hist_mask.unsqueeze(2) & hist_mask.unsqueeze(1)
        edge_index_t = dense_to_sparse(mask_t)[0]
        # Keep causal edges only (source step strictly before receiver step) ...
        edge_index_t = edge_index_t[:, edge_index_t[1] > edge_index_t[0]]
        # ... within the temporal attention window (in token steps).
        edge_index_t = edge_index_t[:, edge_index_t[1] - edge_index_t[0] <= self.time_span / self.shift]
        rel_pos_t = pos_t[edge_index_t[0]] - pos_t[edge_index_t[1]]
        rel_head_t = wrap_angle(head_t[edge_index_t[0]] - head_t[edge_index_t[1]])
        # Features: distance, bearing in the receiver's frame, heading gap, step gap.
        r_t = torch.stack(
            [torch.norm(rel_pos_t[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_t[edge_index_t[1]], nbr_vector=rel_pos_t[:, :2]),
             rel_head_t,
             edge_index_t[0] - edge_index_t[1]], dim=-1)
        r_t = self.r_t_emb(continuous_inputs=r_t, categorical_embs=None)
        return edge_index_t, r_t
def build_interaction_edge(self, pos_a, head_a, head_vector_a, batch_s, mask_s):
pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
head_s = head_a.transpose(0, 1).reshape(-1)
head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
edge_index_a2a = radius_graph(x=pos_s[:, :2], r=self.a2a_radius, batch=batch_s, loop=False,
max_num_neighbors=300)
edge_index_a2a = subgraph(subset=mask_s, edge_index=edge_index_a2a)[0]
rel_pos_a2a = pos_s[edge_index_a2a[0]] - pos_s[edge_index_a2a[1]]
rel_head_a2a = wrap_angle(head_s[edge_index_a2a[0]] - head_s[edge_index_a2a[1]])
r_a2a = torch.stack(
[torch.norm(rel_pos_a2a[:, :2], p=2, dim=-1),
angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_a2a[1]], nbr_vector=rel_pos_a2a[:, :2]),
rel_head_a2a], dim=-1)
r_a2a = self.r_a2a_emb(continuous_inputs=r_a2a, categorical_embs=None)
return edge_index_a2a, r_a2a
    def build_map2agent_edge(self, data, num_step, agent_category, pos_a, head_a, head_vector_a, mask,
                             batch_s, batch_pl):
        """Build map-token -> agent attention edges within self.pl2a_radius.

        Map tokens are replicated once per time step so each (agent, step) node
        can attend to nearby map tokens; edges whose agent endpoint is masked
        out are dropped. `agent_category` is accepted but not used here.
        Returns (edge_index_pl2a, r_pl2a).
        """
        mask_pl2a = mask.clone()
        # Step-major flattening to match batch_s / batch_pl layout.
        mask_pl2a = mask_pl2a.transpose(0, 1).reshape(-1)
        pos_s = pos_a.transpose(0, 1).reshape(-1, self.input_dim)
        head_s = head_a.transpose(0, 1).reshape(-1)
        head_vector_s = head_vector_a.transpose(0, 1).reshape(-1, 2)
        pos_pl = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pl = data['pt_token']['orientation'].contiguous()
        # Replicate map tokens once per time step.
        pos_pl = pos_pl.repeat(num_step, 1)
        orient_pl = orient_pl.repeat(num_step)
        edge_index_pl2a = radius(x=pos_s[:, :2], y=pos_pl[:, :2], r=self.pl2a_radius,
                                 batch_x=batch_s, batch_y=batch_pl, max_num_neighbors=300)
        # Keep only edges whose receiving agent step is valid.
        edge_index_pl2a = edge_index_pl2a[:, mask_pl2a[edge_index_pl2a[1]]]
        rel_pos_pl2a = pos_pl[edge_index_pl2a[0]] - pos_s[edge_index_pl2a[1]]
        rel_orient_pl2a = wrap_angle(orient_pl[edge_index_pl2a[0]] - head_s[edge_index_pl2a[1]])
        # Features: distance, bearing in the agent's frame, orientation gap.
        r_pl2a = torch.stack(
            [torch.norm(rel_pos_pl2a[:, :2], p=2, dim=-1),
             angle_between_2d_vectors(ctr_vector=head_vector_s[edge_index_pl2a[1]], nbr_vector=rel_pos_pl2a[:, :2]),
             rel_orient_pl2a], dim=-1)
        r_pl2a = self.r_pt2a_emb(continuous_inputs=r_pl2a, categorical_embs=None)
        return edge_index_pl2a, r_pl2a
    def forward(self,
                data: HeteroData,
                map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Teacher-forced next-motion-token prediction over all token steps.

        Args:
            data: heterogeneous batch with tokenized agent and map-token fields.
            map_enc: map decoder output; 'x_pt' holds map-token features.

        Returns:
            Dict with agent features ('x_a'), top-10 predicted next-token
            indices, raw logits, ground-truth next-token indices, and the
            evaluation mask for the next-token classification loss.
        """
        pos_a = data['agent']['token_pos']
        head_a = data['agent']['token_heading']
        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
        num_agent, num_step, traj_dim = pos_a.shape
        agent_category = data['agent']['category']
        agent_token_index = data['agent']['token_idx']
        feat_a, agent_token_traj = self.agent_token_embedding(data, agent_category, agent_token_index,
                                                              pos_a, head_vector_a)
        agent_valid_mask = data['agent']['agent_valid_mask'].clone()
        # eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
        # agent_valid_mask[~eval_mask] = False
        mask = agent_valid_mask
        edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask)
        # Offset batch indices per time step so radius queries stay within one step.
        if isinstance(data, Batch):
            batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
                                 for t in range(num_step)], dim=0)
            batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
                                  for t in range(num_step)], dim=0)
        else:
            batch_s = torch.arange(num_step,
                                   device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
            batch_pl = torch.arange(num_step,
                                    device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
        mask_s = mask.transpose(0, 1).reshape(-1)
        edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a, batch_s, mask_s)
        # NOTE: mutates `mask` (aliases agent_valid_mask) AFTER the a2a edges
        # were built — only category-3 agents receive map->agent edges.
        mask[agent_category != 3] = False
        edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
                                                            head_vector_a, mask, batch_s, batch_pl)
        # Alternate temporal, map->agent and agent->agent attention per layer,
        # reshaping between agent-major and step-major flattened layouts.
        for i in range(self.num_layers):
            feat_a = feat_a.reshape(-1, self.hidden_dim)
            feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
            feat_a = feat_a.reshape(-1, num_step,
                                    self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
            feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
                repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
            feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
            feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
        # NOTE(review): names here are misleading — agent_token_traj is
        # (num_agent, num_step, token_size, 4, 2), so "hidden_dim" is token_size.
        num_agent, num_step, hidden_dim, traj_num, traj_dim = agent_token_traj.shape
        next_token_prob = self.token_predict_head(feat_a)
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
        # Ground truth for step t is the token at t+1 (roll left by one step).
        next_token_index_gt = agent_token_index.roll(shifts=-1, dims=1)
        # A step counts for evaluation only if it and both temporal neighbours
        # are valid; the last step has no successor, so it is always excluded.
        next_token_eval_mask = mask.clone()
        next_token_eval_mask = next_token_eval_mask * next_token_eval_mask.roll(shifts=-1, dims=1) * next_token_eval_mask.roll(shifts=1, dims=1)
        next_token_eval_mask[:, -1] = False
        return {'x_a': feat_a,
                'next_token_idx': next_token_idx,
                'next_token_prob': next_token_prob,
                'next_token_idx_gt': next_token_index_gt,
                'next_token_eval_mask': next_token_eval_mask,
                }
    def inference(self,
                  data: HeteroData,
                  map_enc: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Autoregressive rollout: sample motion tokens step by step beyond the
        historical horizon and decode them into positions and headings.

        Args:
            data: heterogeneous batch with tokenized agent and map-token fields.
            map_enc: map decoder output; 'x_pt' holds map-token features.

        Returns:
            Dict with the rolled-out token positions/headings, decoded
            per-frame trajectory and heading, sampled token indices and
            probabilities, ground truth, and evaluation masks.
        """
        eval_mask = data['agent']['valid_mask'][:, self.num_historical_steps - 1]
        pos_a = data['agent']['token_pos'].clone()
        head_a = data['agent']['token_heading'].clone()
        num_agent, num_step, traj_dim = pos_a.shape
        # Zero out everything after the historical horizon; it will be filled
        # in by the rollout below. (num_historical_steps - 1) // shift is the
        # first token step to predict.
        pos_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        head_a[:, (self.num_historical_steps - 1) // self.shift:] = 0
        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
        agent_valid_mask = data['agent']['agent_valid_mask'].clone()
        # All future steps are treated as valid for agents seen at the last
        # historical step; unseen agents are excluded entirely.
        agent_valid_mask[:, (self.num_historical_steps - 1) // self.shift:] = True
        agent_valid_mask[~eval_mask] = False
        agent_token_index = data['agent']['token_idx']
        agent_category = data['agent']['category']
        feat_a, agent_token_traj, agent_token_traj_all, agent_token_emb, categorical_embs = self.agent_token_embedding(
            data,
            agent_category,
            agent_token_index,
            pos_a,
            head_vector_a,
            inference=True)
        agent_type = data["agent"]["type"]
        veh_mask = (agent_type == 0)  # * agent_category==3
        cyc_mask = (agent_type == 2)  # * agent_category==3
        ped_mask = (agent_type == 1)  # * agent_category==3
        av_mask = data["agent"]["av_index"]  # NOTE(review): unused in this method
        # Number of 0.1 s frames to roll out; one token covers self.shift frames.
        self.num_recurrent_steps_val = data["agent"]['position'].shape[1]-self.num_historical_steps
        pred_traj = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, 2, device=feat_a.device)
        pred_head = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val, device=feat_a.device)
        pred_prob = torch.zeros(data["agent"].num_nodes, self.num_recurrent_steps_val // self.shift, device=feat_a.device)
        next_token_idx_list = []
        mask = agent_valid_mask.clone()
        # Cache per-layer features so later iterations only recompute the
        # current token step instead of the full sequence.
        feat_a_t_dict = {}
        for t in range(self.num_recurrent_steps_val // self.shift):
            if t == 0:
                # First step: condition on the full history only.
                inference_mask = mask.clone()
                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t:] = False
            else:
                # Later steps: only the step just produced needs attention.
                inference_mask = torch.zeros_like(mask)
                inference_mask[:, (self.num_historical_steps - 1) // self.shift + t - 1] = True
            edge_index_t, r_t = self.build_temporal_edge(pos_a, head_a, head_vector_a, num_agent, mask, inference_mask)
            # Note: the comprehension variable t below shadows the loop's t
            # only inside the list comprehension scope; the outer t is unaffected.
            if isinstance(data, Batch):
                batch_s = torch.cat([data['agent']['batch'] + data.num_graphs * t
                                     for t in range(num_step)], dim=0)
                batch_pl = torch.cat([data['pt_token']['batch'] + data.num_graphs * t
                                      for t in range(num_step)], dim=0)
            else:
                batch_s = torch.arange(num_step,
                                       device=pos_a.device).repeat_interleave(data['agent']['num_nodes'])
                batch_pl = torch.arange(num_step,
                                        device=pos_a.device).repeat_interleave(data['pt_token']['num_nodes'])
            # In the inference stage, we only infer the current stage for recurrent
            edge_index_pl2a, r_pl2a = self.build_map2agent_edge(data, num_step, agent_category, pos_a, head_a,
                                                                head_vector_a,
                                                                inference_mask, batch_s,
                                                                batch_pl)
            mask_s = inference_mask.transpose(0, 1).reshape(-1)
            edge_index_a2a, r_a2a = self.build_interaction_edge(pos_a, head_a, head_vector_a,
                                                                batch_s, mask_s)
            for i in range(self.num_layers):
                # Reuse cached features from the previous rollout iteration.
                if i in feat_a_t_dict:
                    feat_a = feat_a_t_dict[i]
                feat_a = feat_a.reshape(-1, self.hidden_dim)
                feat_a = self.t_attn_layers[i](feat_a, r_t, edge_index_t)
                feat_a = feat_a.reshape(-1, num_step,
                                        self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
                feat_a = self.pt2a_attn_layers[i]((map_enc['x_pt'].repeat_interleave(
                    repeats=num_step, dim=0).reshape(-1, num_step, self.hidden_dim).transpose(0, 1).reshape(
                    -1, self.hidden_dim), feat_a), r_pl2a, edge_index_pl2a)
                feat_a = self.a2a_attn_layers[i](feat_a, r_a2a, edge_index_a2a)
                feat_a = feat_a.reshape(num_step, -1, self.hidden_dim).transpose(0, 1)
                # Cache this layer's output; on later iterations only the
                # freshly computed step is written back.
                if i+1 not in feat_a_t_dict:
                    feat_a_t_dict[i+1] = feat_a
                else:
                    feat_a_t_dict[i+1][:, (self.num_historical_steps - 1) // self.shift - 1 + t] = feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
            # Predict next-token logits from the most recent step's features.
            next_token_prob = self.token_predict_head(feat_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t])
            next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
            topk_prob, next_token_idx = torch.topk(next_token_prob_softmax, k=self.beam_size, dim=-1)
            # NOTE(review): the literal 6 assumes self.shift + 1 == 6 (shift=5)
            # — confirm if shift is ever changed.
            expanded_index = next_token_idx[..., None, None, None].expand(-1, -1, 6, 4, 2)
            next_token_traj = torch.gather(agent_token_traj_all, 1, expanded_index)
            # Rotate candidate token trajectories into the global frame using
            # the heading at the previous token step.
            theta = head_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t]
            cos, sin = theta.cos(), theta.sin()
            rot_mat = torch.zeros((num_agent, 2, 2), device=theta.device)
            rot_mat[:, 0, 0] = cos
            rot_mat[:, 0, 1] = sin
            rot_mat[:, 1, 0] = -sin
            rot_mat[:, 1, 1] = cos
            agent_diff_rel = torch.bmm(next_token_traj.view(-1, 4, 2),
                                       rot_mat[:, None, None, ...].repeat(1, self.beam_size, self.shift + 1, 1, 1).view(
                                           -1, 2, 2)).view(num_agent, self.beam_size, self.shift + 1, 4, 2)
            agent_pred_rel = agent_diff_rel + pos_a[:, (self.num_historical_steps - 1) // self.shift - 1 + t, :][:, None, None, None, ...]
            # Sample one of the top-k candidates per agent.
            sample_index = torch.multinomial(topk_prob, 1).to(agent_pred_rel.device)
            agent_pred_rel = agent_pred_rel.gather(dim=1,
                                                   index=sample_index[..., None, None, None].expand(-1, -1, 6, 4,
                                                                                                    2))[:, 0, ...]
            pred_prob[:, t] = topk_prob.gather(dim=-1, index=sample_index)[:, 0]
            # Decode per-frame trajectory: box-center (corner mean) per frame.
            pred_traj[:, t * 5:(t + 1) * 5] = agent_pred_rel[:, 1:, ...].clone().mean(dim=2)
            # Heading from the front-to-back corner vector of the box.
            diff_xy = agent_pred_rel[:, 1:, 0, :] - agent_pred_rel[:, 1:, 3, :]
            pred_head[:, t * 5:(t + 1) * 5] = torch.arctan2(diff_xy[:, :, 1], diff_xy[:, :, 0])
            # Write the rolled-out token pose back so the next iteration
            # conditions on it.
            pos_a[:, (self.num_historical_steps - 1) // self.shift + t] = agent_pred_rel[:, -1, ...].clone().mean(dim=1)
            diff_xy = agent_pred_rel[:, -1, 0, :] - agent_pred_rel[:, -1, 3, :]
            theta = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])
            head_a[:, (self.num_historical_steps - 1) // self.shift + t] = theta
            next_token_idx = next_token_idx.gather(dim=1, index=sample_index)
            next_token_idx = next_token_idx.squeeze(-1)
            next_token_idx_list.append(next_token_idx[:, None])
            # Update the token embedding at the new step per agent type.
            agent_token_emb[veh_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_veh[
                next_token_idx[veh_mask]]
            agent_token_emb[ped_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_ped[
                next_token_idx[ped_mask]]
            agent_token_emb[cyc_mask, (self.num_historical_steps - 1) // self.shift + t] = self.agent_token_emb_cyc[
                next_token_idx[cyc_mask]]
            # Recompute continuous motion features with the updated poses.
            motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
                                         pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
            head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
            # Velocity per frame: displacement over token duration (0.1 s per frame).
            vel = motion_vector_a.clone() / (0.1 * self.shift)
            vel[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
            motion_vector_a[:, (self.num_historical_steps - 1) // self.shift + 1 + t:] = 0
            x_a = torch.stack(
                [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2])], dim=-1)
            x_a = self.x_a_emb(continuous_inputs=x_a.view(-1, x_a.size(-1)),
                               categorical_embs=categorical_embs)
            x_a = x_a.view(-1, num_step, self.hidden_dim)
            # Fuse updated token embeddings with continuous features.
            feat_a = torch.cat((agent_token_emb, x_a), dim=-1)
            feat_a = self.fusion_emb(feat_a)
        agent_valid_mask[agent_category != 3] = False
        return {
            'pos_a': pos_a[:, (self.num_historical_steps - 1) // self.shift:],
            'head_a': head_a[:, (self.num_historical_steps - 1) // self.shift:],
            'gt': data['agent']['position'][:, self.num_historical_steps:, :self.input_dim].contiguous(),
            'valid_mask': agent_valid_mask[:, self.num_historical_steps:],
            'pred_traj': pred_traj,
            'pred_head': pred_head,
            'next_token_idx': torch.cat(next_token_idx_list, dim=-1),
            'next_token_idx_gt': agent_token_index.roll(shifts=-1, dims=1),
            'next_token_eval_mask': data['agent']['agent_valid_mask'],
            'pred_prob': pred_prob,
            'vel': vel
        }
================================================
FILE: smart/modules/map_decoder.py
================================================
import os.path
from typing import Dict
import torch
import torch.nn as nn
from torch_cluster import radius_graph
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from torch_geometric.utils import dense_to_sparse, subgraph
from smart.utils.nan_checker import check_nan_inf
from smart.layers.attention_layer import AttentionLayer
from smart.layers import MLPLayer
from smart.layers.fourier_embedding import FourierEmbedding, MLPEmbedding
from smart.utils import angle_between_2d_vectors
from smart.utils import merge_edges
from smart.utils import weight_init
from smart.utils import wrap_angle
import pickle
class SMARTMapDecoder(nn.Module):
    """Map-token decoder: embeds road-map tokens, refines them with pt2pt
    attention layers, and predicts masked map-token indices (a BERT-style
    map-token classification head used as an auxiliary task)."""

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 num_freq_bands: int,
                 num_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token) -> None:
        """Build embeddings, relational Fourier features and attention stack.

        Args:
            dataset: dataset name (stored, not read in this class).
            input_dim: spatial dimensionality of positions (2 or 3).
            hidden_dim: feature width of all embeddings and attention layers.
            num_historical_steps: history length (stored, not read here).
            pl2pl_radius: radius for map-token to map-token edges.
            num_freq_bands: Fourier embedding frequency bands.
            num_layers: number of pt2pt attention layers.
            num_heads / head_dim / dropout: attention hyperparameters.
            map_token: dict with 'traj_src', the raw map-token trajectories.
        """
        super(SMARTMapDecoder, self).__init__()
        self.dataset = dataset
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_historical_steps = num_historical_steps
        self.pl2pl_radius = pl2pl_radius
        self.num_freq_bands = num_freq_bands
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout
        # Relative pt2pt features: distance, bearing, (dz if 3D), orientation gap.
        if input_dim == 2:
            input_dim_r_pt2pt = 3
        elif input_dim == 3:
            input_dim_r_pt2pt = 4
        else:
            raise ValueError('{} is not a valid dimension'.format(input_dim))
        self.type_pt_emb = nn.Embedding(17, hidden_dim)
        self.side_pt_emb = nn.Embedding(4, hidden_dim)
        self.polygon_type_emb = nn.Embedding(4, hidden_dim)
        self.light_pl_emb = nn.Embedding(4, hidden_dim)
        self.r_pt2pt_emb = FourierEmbedding(input_dim=input_dim_r_pt2pt, hidden_dim=hidden_dim,
                                            num_freq_bands=num_freq_bands)
        self.pt2pt_layers = nn.ModuleList(
            [AttentionLayer(hidden_dim=hidden_dim, num_heads=num_heads, head_dim=head_dim, dropout=dropout,
                            bipartite=False, has_pos_emb=True) for _ in range(num_layers)]
        )
        # Map-token vocabulary size; classification head over it.
        self.token_size = 1024
        self.token_predict_head = MLPLayer(input_dim=hidden_dim, hidden_dim=hidden_dim,
                                           output_dim=self.token_size)
        # 11 2-D points per map token, flattened to 22 features.
        input_dim_token = 22
        self.token_emb = MLPEmbedding(input_dim=input_dim_token, hidden_dim=hidden_dim)
        self.map_token = map_token
        self.apply(weight_init)
        # When True, edges touching invalid map tokens are removed in forward().
        self.mask_pt = False

    def maybe_autocast(self, dtype=torch.float32):
        """Return a CUDA autocast context for the given dtype.

        NOTE(review): torch.cuda.amp.autocast is the legacy spelling of
        torch.amp.autocast('cuda', ...) — confirm the targeted torch version.
        """
        return torch.cuda.amp.autocast(dtype=dtype)

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Encode map tokens and predict the token index at masked positions.

        Reads data['pt_token'] (positions, orientation, token/type indices,
        validity/pred/target masks) and the token->polygon edge index for
        traffic-light state. Returns the refined features 'x_pt' plus the
        masked-token prediction outputs.
        """
        pt_valid_mask = data['pt_token']['pt_valid_mask']
        pt_pred_mask = data['pt_token']['pt_pred_mask']
        pt_target_mask = data['pt_token']['pt_target_mask']
        mask_s = pt_valid_mask
        pos_pt = data['pt_token']['position'][:, :self.input_dim].contiguous()
        orient_pt = data['pt_token']['orientation'].contiguous()
        orient_vector_pt = torch.stack([orient_pt.cos(), orient_pt.sin()], dim=-1)
        # Embed the whole token vocabulary once, then index per map token.
        token_sample_pt = self.map_token['traj_src'].to(pos_pt.device).to(torch.float)
        pt_token_emb_src = self.token_emb(token_sample_pt.view(token_sample_pt.shape[0], -1))
        pt_token_emb = pt_token_emb_src[data['pt_token']['token_idx']]
        # NOTE(review): both branches are identical; kept for symmetry with
        # the 2-D/3-D handling elsewhere.
        if self.input_dim == 2:
            x_pt = pt_token_emb
        elif self.input_dim == 3:
            x_pt = pt_token_emb
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))
        # Traffic-light state comes from the polygon each token belongs to.
        token2pl = data[('pt_token', 'to', 'map_polygon')]['edge_index']
        token_light_type = data['map_polygon']['light_type'][token2pl[1]]
        x_pt_categorical_embs = [self.type_pt_emb(data['pt_token']['type'].long()),
                                 self.polygon_type_emb(data['pt_token']['pl_type'].long()),
                                 self.light_pl_emb(token_light_type.long()),]
        x_pt = x_pt + torch.stack(x_pt_categorical_embs).sum(dim=0)
        edge_index_pt2pt = radius_graph(x=pos_pt[:, :2], r=self.pl2pl_radius,
                                        batch=data['pt_token']['batch'] if isinstance(data, Batch) else None,
                                        loop=False, max_num_neighbors=100)
        if self.mask_pt:
            edge_index_pt2pt = subgraph(subset=mask_s, edge_index=edge_index_pt2pt)[0]
        rel_pos_pt2pt = pos_pt[edge_index_pt2pt[0]] - pos_pt[edge_index_pt2pt[1]]
        rel_orient_pt2pt = wrap_angle(orient_pt[edge_index_pt2pt[0]] - orient_pt[edge_index_pt2pt[1]])
        if self.input_dim == 2:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_orient_pt2pt], dim=-1)
        elif self.input_dim == 3:
            r_pt2pt = torch.stack(
                [torch.norm(rel_pos_pt2pt[:, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=orient_vector_pt[edge_index_pt2pt[1]],
                                          nbr_vector=rel_pos_pt2pt[:, :2]),
                 rel_pos_pt2pt[:, -1],
                 rel_orient_pt2pt], dim=-1)
        else:
            raise ValueError('{} is not a valid dimension'.format(self.input_dim))
        r_pt2pt = self.r_pt2pt_emb(continuous_inputs=r_pt2pt, categorical_embs=None)
        for i in range(self.num_layers):
            x_pt = self.pt2pt_layers[i](x_pt, r_pt2pt, edge_index_pt2pt)
        # Classify only masked (to-be-predicted) tokens; ground truth comes
        # from the target-masked token indices.
        next_token_prob = self.token_predict_head(x_pt[pt_pred_mask])
        next_token_prob_softmax = torch.softmax(next_token_prob, dim=-1)
        _, next_token_idx = torch.topk(next_token_prob_softmax, k=10, dim=-1)
        next_token_index_gt = data['pt_token']['token_idx'][pt_target_mask]
        return {
            'x_pt': x_pt,
            'map_next_token_idx': next_token_idx,
            'map_next_token_prob': next_token_prob,
            'map_next_token_idx_gt': next_token_index_gt,
            # NOTE(review): pt_pred_mask[pt_pred_mask] is an all-True vector of
            # the same length as the predictions — confirm this is intended.
            'map_next_token_eval_mask': pt_pred_mask[pt_pred_mask]
        }
================================================
FILE: smart/modules/smart_decoder.py
================================================
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch_geometric.data import HeteroData
from smart.modules.agent_decoder import SMARTAgentDecoder
from smart.modules.map_decoder import SMARTMapDecoder
class SMARTDecoder(nn.Module):
    """Top-level SMART decoder: a map-token encoder followed by an
    agent-motion-token decoder, with their outputs merged into one dict."""

    def __init__(self,
                 dataset: str,
                 input_dim: int,
                 hidden_dim: int,
                 num_historical_steps: int,
                 pl2pl_radius: float,
                 time_span: Optional[int],
                 pl2a_radius: float,
                 a2a_radius: float,
                 num_freq_bands: int,
                 num_map_layers: int,
                 num_agent_layers: int,
                 num_heads: int,
                 head_dim: int,
                 dropout: float,
                 map_token: Dict,
                 token_data: Dict,
                 use_intention=False,
                 token_size=512) -> None:
        """Instantiate the map and agent sub-decoders.

        `use_intention` is accepted for interface compatibility but unused.
        """
        super(SMARTDecoder, self).__init__()
        # Hyperparameters shared by both sub-decoders.
        shared = dict(dataset=dataset,
                      input_dim=input_dim,
                      hidden_dim=hidden_dim,
                      num_historical_steps=num_historical_steps,
                      num_freq_bands=num_freq_bands,
                      num_heads=num_heads,
                      head_dim=head_dim,
                      dropout=dropout)
        self.map_encoder = SMARTMapDecoder(pl2pl_radius=pl2pl_radius,
                                           num_layers=num_map_layers,
                                           map_token=map_token,
                                           **shared)
        self.agent_encoder = SMARTAgentDecoder(time_span=time_span,
                                               pl2a_radius=pl2a_radius,
                                               a2a_radius=a2a_radius,
                                               num_layers=num_agent_layers,
                                               token_size=token_size,
                                               token_data=token_data,
                                               **shared)
        self.map_enc = None

    def forward(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Encode the map, then decode agent motion; merge both result dicts."""
        map_enc = self.map_encoder(data)
        return {**map_enc, **self.agent_encoder(data, map_enc)}

    def inference(self, data: HeteroData) -> Dict[str, torch.Tensor]:
        """Autoregressive rollout: encode the map, then run agent inference."""
        map_enc = self.map_encoder(data)
        return {**map_enc, **self.agent_encoder.inference(data, map_enc)}

    def inference_no_map(self, data: HeteroData, map_enc) -> Dict[str, torch.Tensor]:
        """Rollout variant that reuses a precomputed map encoding."""
        return {**map_enc, **self.agent_encoder.inference(data, map_enc)}
================================================
FILE: smart/preprocess/__init__.py
================================================
================================================
FILE: smart/preprocess/preprocess.py
================================================
import numpy as np
import pandas as pd
import os
import torch
from typing import Any, Dict, List, Optional
# --- Module-level configuration -------------------------------------------
predict_unseen_agents = False  # keep only agents observed during history
vector_repr = True  # vector representation: step t valid only if t-1 is too
# Closed vocabularies used to encode categorical agent / map attributes.
_agent_types = ['vehicle', 'pedestrian', 'cyclist', 'background']
_polygon_types = ['VEHICLE', 'BIKE', 'BUS', 'PEDESTRIAN']
_polygon_light_type = ['LANE_STATE_STOP', 'LANE_STATE_GO', 'LANE_STATE_CAUTION', 'LANE_STATE_UNKNOWN']
_point_types = ['DASH_SOLID_YELLOW', 'DASH_SOLID_WHITE', 'DASHED_WHITE', 'DASHED_YELLOW',
                'DOUBLE_SOLID_YELLOW', 'DOUBLE_SOLID_WHITE', 'DOUBLE_DASH_YELLOW', 'DOUBLE_DASH_WHITE',
                'SOLID_YELLOW', 'SOLID_WHITE', 'SOLID_DASH_WHITE', 'SOLID_DASH_YELLOW', 'EDGE',
                'NONE', 'UNKNOWN', 'CROSSWALK', 'CENTERLINE']
_point_sides = ['LEFT', 'RIGHT', 'CENTER']
_polygon_to_polygon_types = ['NONE', 'PRED', 'SUCC', 'LEFT', 'RIGHT']
_polygon_is_intersections = [True, False, None]
# Waymo lane-type id -> polygon type name.
Lane_type_hash = {
    4: "BIKE",
    3: "VEHICLE",
    2: "VEHICLE",
    1: "BUS"
}
# Waymo boundary-type id -> point (lane-boundary) type name.
boundary_type_hash = {
    5: "UNKNOWN",
    6: "DASHED_WHITE",
    7: "SOLID_WHITE",
    8: "DOUBLE_DASH_WHITE",
    9: "DASHED_YELLOW",
    10: "DOUBLE_DASH_YELLOW",
    11: "SOLID_YELLOW",
    12: "DOUBLE_SOLID_YELLOW",
    13: "DASH_SOLID_YELLOW",
    14: "UNKNOWN",
    15: "EDGE",
    16: "EDGE"
}


def get_agent_features(df: pd.DataFrame, av_id, num_historical_steps=10, dim=3, num_steps=91) -> Dict[str, Any]:
    """Convert a per-timestep agent DataFrame into dense per-agent tensors.

    Args:
        df: long-format table with one row per (track_id, timestep); reads the
            columns 'track_id', 'timestep', 'object_type', 'object_category',
            'position_x/y/z', 'heading', 'velocity_x/y', 'length', 'width',
            'height'.
        av_id: track_id of the autonomous vehicle; must be present in df.
        num_historical_steps: number of observed (conditioning) steps.
        dim: feature dimensionality of position/velocity/shape tensors
            (assumes dim >= 3 — positions and shapes write 3 components).
        num_steps: total number of time steps covered by the output tensors.

    Returns:
        Dict with 'num_nodes', 'av_index', per-agent 'valid_mask' and
        'predict_mask' of shape (num_agents, num_steps), the raw 'id' list,
        'type'/'category' codes, and dense 'position', 'heading', 'velocity'
        and 'shape' tensors.
    """
    if not predict_unseen_agents:  # filter out agents that are unseen during the historical time steps
        historical_df = df[df['timestep'] == num_historical_steps - 1]
        agent_ids = list(historical_df['track_id'].unique())
        df = df[df['track_id'].isin(agent_ids)]
    else:
        agent_ids = list(df['track_id'].unique())
    num_agents = len(agent_ids)
    # O(1) id -> row-index lookup; list.index inside the groupby loop below
    # would make the whole pass O(n^2) in the number of agents.
    agent_index_of = {track_id: idx for idx, track_id in enumerate(agent_ids)}

    # initialization
    valid_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    current_valid_mask = torch.zeros(num_agents, dtype=torch.bool)
    predict_mask = torch.zeros(num_agents, num_steps, dtype=torch.bool)
    agent_id: List[Optional[str]] = [None] * num_agents
    agent_type = torch.zeros(num_agents, dtype=torch.uint8)
    agent_category = torch.zeros(num_agents, dtype=torch.uint8)
    position = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    heading = torch.zeros(num_agents, num_steps, dtype=torch.float)
    velocity = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)
    shape = torch.zeros(num_agents, num_steps, dim, dtype=torch.float)

    for track_id, track_df in df.groupby('track_id'):
        agent_idx = agent_index_of[track_id]
        agent_steps = track_df['timestep'].values
        valid_mask[agent_idx, agent_steps] = True
        # Validity at the last historical step, BEFORE the vector-repr
        # adjustment below — this decides whether the agent is predicted.
        current_valid_mask[agent_idx] = valid_mask[agent_idx, num_historical_steps - 1]
        predict_mask[agent_idx, agent_steps] = True
        if vector_repr:  # a time step t is valid only when both t and t-1 are valid
            valid_mask[agent_idx, 1: num_historical_steps] = (
                valid_mask[agent_idx, :num_historical_steps - 1] &
                valid_mask[agent_idx, 1: num_historical_steps])
            valid_mask[agent_idx, 0] = False
        # Historical steps are never prediction targets; agents invalid at the
        # last historical step are not predicted at all.
        predict_mask[agent_idx, :num_historical_steps] = False
        if not current_valid_mask[agent_idx]:
            predict_mask[agent_idx, num_historical_steps:] = False

        agent_id[agent_idx] = track_id
        agent_type[agent_idx] = _agent_types.index(track_df['object_type'].values[0])
        agent_category[agent_idx] = track_df['object_category'].values[0]
        position[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['position_x'].values,
                                                                          track_df['position_y'].values,
                                                                          track_df['position_z'].values],
                                                                         axis=-1)).float()
        heading[agent_idx, agent_steps] = torch.from_numpy(track_df['heading'].values).float()
        velocity[agent_idx, agent_steps, :2] = torch.from_numpy(np.stack([track_df['velocity_x'].values,
                                                                          track_df['velocity_y'].values],
                                                                         axis=-1)).float()
        shape[agent_idx, agent_steps, :3] = torch.from_numpy(np.stack([track_df['length'].values,
                                                                       track_df['width'].values,
                                                                       track_df["height"].values],
                                                                      axis=-1)).float()

    av_idx = agent_index_of[av_id]
    return {
        'num_nodes': num_agents,
        'av_index': av_idx,
        'valid_mask': valid_mask,
        'predict_mask': predict_mask,
        'id': agent_id,
        'type': agent_type,
        'category': agent_category,
        'position': position,
        'heading': heading,
        'velocity': velocity,
        'shape': shape
    }
================================================
FILE: smart/tokens/__init__.py
================================================
================================================
FILE: smart/transforms/__init__.py
================================================
from smart.transforms.target_builder import WaymoTargetBuilder
================================================
FILE: smart/transforms/target_builder.py
================================================
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import BaseTransform
from smart.utils import wrap_angle
from smart.utils.log import Logging
def to_16(data):
    """Recursively downcast float32 tensors (including those nested inside
    dicts) to float16; all other values pass through unchanged."""
    if isinstance(data, dict):
        # Recurse into the mapping, rewriting each value in place.
        for key in data:
            data[key] = to_16(data[key])
    if isinstance(data, torch.Tensor) and data.dtype == torch.float32:
        data = data.to(torch.float16)
    return data
def tofloat32(data):
    """Recursively downcast float64 tensors inside a (possibly nested) mapping
    to float32; non-tensor and non-float64 values are left untouched."""
    for key in data:
        item = data[key]
        if isinstance(item, dict):
            data[key] = tofloat32(item)
        elif isinstance(item, torch.Tensor) and item.dtype == torch.float64:
            data[key] = item.to(torch.float32)
    return data
class WaymoTargetBuilder(BaseTransform):
    """PyG transform that selects which Waymo agents are training targets.

    On each sample it marks the ego agent (category 5) and up to ``max_num``
    valid, in-range, non-background vehicles (category 3), then wraps the
    data dict into a :class:`HeteroData`.
    """

    def __init__(self,
                 num_historical_steps: int,
                 num_future_steps: int,
                 mode="train") -> None:
        # Number of observed (history) and predicted (future) timesteps.
        self.num_historical_steps = num_historical_steps
        self.num_future_steps = num_future_steps
        self.mode = mode
        self.num_features = 3
        self.augment = False
        self.logger = Logging().log(level='DEBUG')

    def score_ego_agent(self, agent):
        """Tag the ego (AV) agent with category 5."""
        av_index = agent['av_index']
        agent["category"][av_index] = 5
        return agent

    def clip(self, agent, max_num=32):
        """Keep only the ``max_num`` agents closest to the ego at the last
        observed step (the ego itself is always kept), dropping every
        per-agent field down to the surviving rows.

        Obstacle-type agents (type == 3) get their distance pushed to 1e6 so
        they are excluded from the top-k. ``num_nodes`` and ``av_index`` are
        updated to reflect the clipped agent set.
        """
        av_index = agent["av_index"]
        valid = agent['valid_mask']
        ego_pos = agent["position"][av_index]
        obstacle_mask = agent['type'] == 3
        # Distance of every agent to the ego at the last history step.
        distance = torch.norm(agent["position"][:, self.num_historical_steps-1, :2] - ego_pos[self.num_historical_steps-1, :2], dim=-1)  # keep the closest agents near the ego car
        distance[obstacle_mask] = 10e5  # push obstacles out of the top-k
        sort_idx = distance.sort()[1]
        mask = torch.zeros(valid.shape[0])
        mask[sort_idx[:max_num]] = 1
        mask = mask.to(torch.bool)
        mask[av_index] = True  # never drop the ego agent
        # New ego index = number of kept agents that precede the old index.
        new_av_index = mask[:av_index].sum()
        agent["num_nodes"] = int(mask.sum())
        agent["av_index"] = int(new_av_index)
        excluded = ["num_nodes", "av_index", "ego"]
        for key, val in agent.items():
            if key in excluded:
                continue
            if key == "id":
                # "id" is a plain Python list; mask it through numpy.
                val = list(np.array(val)[mask])
                agent[key] = val
                continue
            if len(val.size()) > 1:
                agent[key] = val[mask, ...]
            else:
                agent[key] = val[mask]
        return agent

    def score_nearby_vehicle(self, agent, max_num=10):
        """Tag the (max_num - 1) nearest non-obstacle agents with category 3.

        All categories are reset to 0 first; slice [1:max_num] skips sort
        position 0 (the ego itself, at distance 0).
        """
        av_index = agent['av_index']
        agent["category"] = torch.zeros_like(agent["category"])
        obstacle_mask = agent['type'] == 3
        pos = agent["position"][av_index, self.num_historical_steps, :2]
        distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
        distance[obstacle_mask] = 10e5
        sort_idx = distance.sort()[1]
        nearby_mask = torch.zeros(distance.shape[0])
        nearby_mask[sort_idx[1:max_num]] = 1
        nearby_mask = nearby_mask.bool()
        agent["category"][nearby_mask] = 3
        agent["category"][obstacle_mask] = 0  # obstacles never become targets

    def score_trained_vehicle(self, agent, max_num=10, min_distance=0):
        """Tag up to ``max_num`` vehicles as prediction targets (category 3).

        Candidates must be within 100 m of the ego at the first future step,
        have at least one valid future step, and not be background (type 3).
        ``valid_mask`` is also clipped to the 150 m perception range.
        When more than ``max_num`` candidates exist, a random subset is kept.
        """
        av_index = agent['av_index']
        agent["category"] = torch.zeros_like(agent["category"])
        pos = agent["position"][av_index, self.num_historical_steps, :2]
        distance = torch.norm(agent["position"][:, self.num_historical_steps, :2] - pos, dim=-1)
        distance_all_time = torch.norm(agent["position"][:, :, :2] - agent["position"][av_index, :, :2], dim=-1)
        # NOTE(review): despite the name, invalid_mask is True where the agent
        # IS within perception range (and hence stays valid).
        invalid_mask = distance_all_time < 150  # we do not believe the perception out of range of 150 meters
        agent["valid_mask"] = agent["valid_mask"] * invalid_mask
        # we do not predict vehicle too far away from ego car
        closet_vehicle = distance < 100
        valid = agent['valid_mask']
        valid_current = valid[:, (self.num_historical_steps):]
        valid_counts = valid_current.sum(1)
        counts_vehicle = valid_counts >= 1
        no_backgroud = agent['type'] != 3
        vehicle2pred = closet_vehicle & counts_vehicle & no_backgroud
        if vehicle2pred.sum() > max_num:
            # too many still vehicle so that train the model using the moving vehicle as much as possible
            true_indices = torch.nonzero(vehicle2pred).squeeze(1)
            selected_indices = true_indices[torch.randperm(true_indices.size(0))[:max_num]]
            vehicle2pred.fill_(False)
            vehicle2pred[selected_indices] = True
        agent["category"][vehicle2pred] = 3

    def rotate_agents(self, position, heading, num_nodes, num_historical_steps, num_future_steps):
        """Express each agent's history and future in its own local frame at
        the last observed step (origin = position, rotation = heading there).

        Returns ``(his, target)`` of shapes
        [num_nodes, num_historical_steps, 4] and [num_nodes, num_future_steps, 4];
        the channel layout is (x, y, z, yaw) for 3-D positions and
        (x, y, yaw, 0-padding) for 2-D ones.
        """
        origin = position[:, num_historical_steps - 1]
        theta = heading[:, num_historical_steps - 1]
        cos, sin = theta.cos(), theta.sin()
        rot_mat = theta.new_zeros(num_nodes, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = -sin
        rot_mat[:, 1, 0] = sin
        rot_mat[:, 1, 1] = cos
        target = origin.new_zeros(num_nodes, num_future_steps, 4)
        # Rotate translated positions into the agent's local frame.
        target[..., :2] = torch.bmm(position[:, num_historical_steps:, :2] -
                                    origin[:, :2].unsqueeze(1), rot_mat)
        his = origin.new_zeros(num_nodes, num_historical_steps, 4)
        his[..., :2] = torch.bmm(position[:, :num_historical_steps, :2] -
                                 origin[:, :2].unsqueeze(1), rot_mat)
        if position.size(2) == 3:
            # 3-D input: z offset goes in channel 2, yaw in channel 3.
            target[..., 2] = (position[:, num_historical_steps:, 2] -
                              origin[:, 2].unsqueeze(-1))
            his[..., 2] = (position[:, :num_historical_steps, 2] -
                           origin[:, 2].unsqueeze(-1))
            target[..., 3] = wrap_angle(heading[:, num_historical_steps:] -
                                        theta.unsqueeze(-1))
            his[..., 3] = wrap_angle(heading[:, :num_historical_steps] -
                                     theta.unsqueeze(-1))
        else:
            # 2-D input: yaw goes in channel 2, channel 3 stays zero.
            target[..., 2] = wrap_angle(heading[:, num_historical_steps:] -
                                        theta.unsqueeze(-1))
            his[..., 2] = wrap_angle(heading[:, :num_historical_steps] -
                                     theta.unsqueeze(-1))
        return his, target

    def __call__(self, data) -> HeteroData:
        """Score ego + trainable vehicles in place, then wrap into HeteroData."""
        agent = data["agent"]
        self.score_ego_agent(agent)
        self.score_trained_vehicle(agent, max_num=32)
        return HeteroData(data)
================================================
FILE: smart/utils/__init__.py
================================================
from smart.utils.geometry import angle_between_2d_vectors
from smart.utils.geometry import angle_between_3d_vectors
from smart.utils.geometry import side_to_directed_lineseg
from smart.utils.geometry import wrap_angle
from smart.utils.graph import add_edges
from smart.utils.graph import bipartite_dense_to_sparse
from smart.utils.graph import complete_graph
from smart.utils.graph import merge_edges
from smart.utils.graph import unbatch
from smart.utils.list import safe_list_index
from smart.utils.weight_init import weight_init
================================================
FILE: smart/utils/cluster_reader.py
================================================
import io
import pickle
import pandas as pd
import json
class LoadScenarioFromCeph:
    """Thin convenience wrapper around a petrel Ceph client for listing,
    reading, and writing scenario artifacts (pickle / csv / json)."""

    def __init__(self):
        # Imported lazily so the class can be defined without petrel installed.
        from petrel_client.client import Client
        self.file_client = Client('~/petreloss.conf')

    def list(self, dir_path):
        """Return the directory entries under *dir_path* as a list."""
        return [entry for entry in self.file_client.list(dir_path)]

    def save(self, data, url):
        """Pickle *data* and upload it to *url*."""
        payload = pickle.dumps(data)
        self.file_client.put(url, payload)

    def read_correct_csv(self, scenario_path):
        """Fetch a UTF-8 csv and parse it with the python engine."""
        raw = self.file_client.get(scenario_path).decode('utf-8')
        return pd.read_csv(io.StringIO(raw), engine="python")

    def contains(self, url):
        """True when *url* exists in the store."""
        return self.file_client.contains(url)

    def read_string(self, csv_url):
        """Fetch a whitespace-separated table and parse it into a DataFrame."""
        from io import StringIO
        text = str(self.file_client.get(csv_url), 'utf-8')
        return pd.read_csv(StringIO(text), sep='\s+', low_memory=False)

    def read(self, scenario_path):
        """Fetch and unpickle the object stored at *scenario_path*."""
        with io.BytesIO(self.file_client.get(scenario_path)) as f:
            return pickle.load(f)

    def read_json(self, path):
        """Fetch and parse the JSON document stored at *path*."""
        with io.BytesIO(self.file_client.get(path)) as f:
            return json.load(f)

    def read_csv(self, scenario_path):
        """Fetch and unpickle (despite the name) the object at *scenario_path*."""
        return pickle.loads(self.file_client.get(scenario_path))

    def read_model(self, model_path):
        # NOTE(review): stub — fetches the checkpoint bytes but never loads
        # or returns them; confirm intended behavior before relying on it.
        with io.BytesIO(self.file_client.get(model_path)) as f:
            pass
================================================
FILE: smart/utils/config.py
================================================
import os
import yaml
import easydict
def load_config_act(path):
    """Parse the YAML config at *path* into an attribute-accessible EasyDict."""
    with open(path, 'r') as stream:
        raw = yaml.load(stream, Loader=yaml.FullLoader)
    return easydict.EasyDict(raw)
def load_config_init(path):
    """Load the init config named *path* from ``init/configs/<path>.yaml``."""
    full_path = os.path.join('init/configs', f'{path}.yaml')
    with open(full_path, 'r') as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)
================================================
FILE: smart/utils/geometry.py
================================================
import math
import torch
def angle_between_2d_vectors(
        ctr_vector: torch.Tensor,
        nbr_vector: torch.Tensor) -> torch.Tensor:
    """Signed angle (radians, in [-pi, pi]) from ctr_vector to nbr_vector
    in the 2-D plane, computed as atan2(cross, dot) over the xy channels."""
    cross = ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0]
    dot = (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1)
    return torch.atan2(cross, dot)
def angle_between_3d_vectors(
        ctr_vector: torch.Tensor,
        nbr_vector: torch.Tensor) -> torch.Tensor:
    """Unsigned angle (radians, in [0, pi]) between two 3-D vectors,
    via atan2(|cross|, dot) for numerical stability near 0 and pi."""
    cross_norm = torch.cross(ctr_vector, nbr_vector, dim=-1).norm(p=2, dim=-1)
    dot = (ctr_vector * nbr_vector).sum(dim=-1)
    return torch.atan2(cross_norm, dot)
def side_to_directed_lineseg(
        query_point: torch.Tensor,
        start_point: torch.Tensor,
        end_point: torch.Tensor) -> str:
    """Classify *query_point* relative to the directed segment start->end
    using the sign of the 2-D cross product: 'LEFT', 'RIGHT', or 'CENTER'."""
    cross = ((end_point[0] - start_point[0]) * (query_point[1] - start_point[1]) -
             (end_point[1] - start_point[1]) * (query_point[0] - start_point[0]))
    if cross > 0:
        return 'LEFT'
    if cross < 0:
        return 'RIGHT'
    return 'CENTER'
def wrap_angle(
        angle: torch.Tensor,
        min_val: float = -math.pi,
        max_val: float = math.pi) -> torch.Tensor:
    """Wrap *angle* into the half-open interval [min_val, max_val)."""
    span = max_val - min_val
    return (angle + max_val) % span + min_val
================================================
FILE: smart/utils/graph.py
================================================
from typing import List, Optional, Tuple, Union
import torch
from torch_geometric.utils import coalesce
from torch_geometric.utils import degree
def add_edges(
        from_edge_index: torch.Tensor,
        to_edge_index: torch.Tensor,
        from_edge_attr: Optional[torch.Tensor] = None,
        to_edge_attr: Optional[torch.Tensor] = None,
        replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Merge the `from` edge set into the `to` edge set.

    With ``replace=True``, `to` edges duplicated in `from` are dropped and
    all `from` edges appended (the new edges win). With ``replace=False``,
    `from` edges already present in `to` are skipped (the old edges win).
    Attributes are merged only when both attr tensors are provided.
    """
    from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype)
    # duplicate[i, j] is True when to-edge i equals from-edge j.
    duplicate = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) &
                 (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0)))
    has_attrs = from_edge_attr is not None and to_edge_attr is not None
    if replace:
        keep_to = ~duplicate.any(dim=1)
        if has_attrs:
            from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
            to_edge_attr = torch.cat([to_edge_attr[keep_to], from_edge_attr], dim=0)
        to_edge_index = torch.cat([to_edge_index[:, keep_to], from_edge_index], dim=1)
    else:
        keep_from = ~duplicate.any(dim=0)
        if has_attrs:
            from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype)
            to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[keep_from]], dim=0)
        to_edge_index = torch.cat([to_edge_index, from_edge_index[:, keep_from]], dim=1)
    return to_edge_index, to_edge_attr
def merge_edges(
        edge_indices: List[torch.Tensor],
        edge_attrs: Optional[List[torch.Tensor]] = None,
        reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Concatenate several edge sets and coalesce duplicate edges,
    combining their attributes with the given *reduce* operation."""
    stacked_index = torch.cat(edge_indices, dim=1)
    stacked_attr = torch.cat(edge_attrs, dim=0) if edge_attrs is not None else None
    return coalesce(edge_index=stacked_index, edge_attr=stacked_attr, reduce=reduce)
def complete_graph(
        num_nodes: Union[int, Tuple[int, int]],
        ptr: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
        loop: bool = False,
        device: Optional[Union[torch.device, str]] = None) -> torch.Tensor:
    """Build the edge index of a complete (optionally bipartite) graph.

    When *ptr* is given, a complete graph is built per batch segment and the
    node ids are offset by the segment start. Self-loops are removed only for
    the homogeneous case (int *num_nodes*) when ``loop`` is False.
    """
    if ptr is None:
        if isinstance(num_nodes, int):
            num_src = num_dst = num_nodes
        else:
            num_src, num_dst = num_nodes
        src_range = torch.arange(num_src, dtype=torch.long, device=device)
        dst_range = torch.arange(num_dst, dtype=torch.long, device=device)
        edge_index = torch.cartesian_prod(src_range, dst_range).t()
    else:
        if isinstance(ptr, torch.Tensor):
            ptr_src = ptr_dst = ptr
            num_src_batch = num_dst_batch = ptr[1:] - ptr[:-1]
        else:
            ptr_src, ptr_dst = ptr
            num_src_batch = ptr_src[1:] - ptr_src[:-1]
            num_dst_batch = ptr_dst[1:] - ptr_dst[:-1]
        # zip truncates the stacked offsets to one (src, dst) start per batch.
        offsets = torch.stack([ptr_src, ptr_dst], dim=1)
        blocks = []
        for n_src, n_dst, offset in zip(num_src_batch, num_dst_batch, offsets):
            pairs = torch.cartesian_prod(
                torch.arange(n_src, dtype=torch.long, device=device),
                torch.arange(n_dst, dtype=torch.long, device=device))
            blocks.append(pairs + offset)
        edge_index = torch.cat(blocks, dim=0).t()
    if isinstance(num_nodes, int) and not loop:
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    return edge_index.contiguous()
def bipartite_dense_to_sparse(adj: torch.Tensor) -> torch.Tensor:
    """Convert a dense (optionally batched) bipartite adjacency matrix into
    a [2, num_edges] edge index; batched ids are offset per batch."""
    nz = adj.nonzero(as_tuple=True)
    if len(nz) == 3:
        # Batched input: shift node ids so each batch owns a disjoint range.
        src = nz[0] * adj.size(1) + nz[1]
        dst = nz[0] * adj.size(2) + nz[2]
        nz = (src, dst)
    return torch.stack(nz, dim=0)
def unbatch(
        src: torch.Tensor,
        batch: torch.Tensor,
        dim: int = 0) -> List[torch.Tensor]:
    """Split *src* along *dim* into one tensor per consecutive batch id."""
    counts = degree(batch, dtype=torch.long).tolist()
    return src.split(counts, dim)
================================================
FILE: smart/utils/list.py
================================================
from typing import Any, List, Optional
def safe_list_index(ls: List[Any], elem: Any) -> Optional[int]:
    """Return the first index of *elem* in *ls*, or None when absent."""
    try:
        position = ls.index(elem)
    except ValueError:
        return None
    return position
================================================
FILE: smart/utils/log.py
================================================
import logging
import time
import os
class Logging:
    """Helper that wires a named logger with one stream handler and one
    per-day file handler under a ``logs/`` directory next to this file.

    ``log`` and ``add_log`` previously duplicated the whole handler-wiring
    block verbatim; it now lives once in :meth:`_attach_handlers`.
    """

    # Shared record format for both stream and file handlers.
    _FORMAT = "%(asctime)s-%(levelname)s-%(filename)s-Line:%(lineno)d-Message:%(message)s"

    def make_log_dir(self, dirname='logs'):
        """Create (if missing) and return the log directory next to this file."""
        now_dir = os.path.dirname(__file__)
        path = os.path.normpath(os.path.join(now_dir, dirname))
        if not os.path.exists(path):
            os.mkdir(path)
        return path

    def get_log_filename(self):
        """Return the per-day log file path, e.g. ``logs/2024-01-01.log``."""
        filename = "{}.log".format(time.strftime("%Y-%m-%d", time.localtime()))
        return os.path.normpath(os.path.join(self.make_log_dir(), filename))

    def _attach_handlers(self, logger):
        """Attach one stream and one file handler, unless the logger already
        has handlers (prevents duplicate output on repeated calls)."""
        if logger.handlers:
            return
        fmt = logging.Formatter(self._FORMAT)
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        fh = logging.FileHandler(filename=self.get_log_filename(), mode='a', encoding="utf-8")
        fh.setFormatter(fmt)
        logger.addHandler(sh)
        logger.addHandler(fh)

    def log(self, level='DEBUG', name="simagent"):
        """Return the logger *name* configured at *level* with shared handlers.

        *level* is a logging level name such as 'DEBUG' or 'INFO'.
        """
        logger = logging.getLogger(name)
        logger.setLevel(getattr(logging, level))
        self._attach_handlers(logger)
        return logger

    def add_log(self, logger, level='DEBUG'):
        """Configure an externally created *logger* with the shared level
        and handlers, and return it."""
        logger.setLevel(getattr(logging, level))
        self._attach_handlers(logger)
        return logger
if __name__ == '__main__':
    # Smoke test: emit one message at each severity through the configured
    # logger; with level INFO, the debug message should be filtered out.
    logger = Logging().log(level='INFO')
    logger.debug("1111111111111111111111")  # generate log records via the logger
    logger.info("222222222222222222222222")
    logger.error("附件为IP飞机外婆家二分IP文件放")
    logger.warning("3333333333333333333333333333")
    logger.critical("44444444444444444444444444")
================================================
FILE: smart/utils/nan_checker.py
================================================
import torch
def check_nan_inf(t, s):
    """Raise AssertionError when tensor *t* contains any inf or nan value;
    *s* labels the tensor in the failure message."""
    has_inf = torch.isinf(t).any()
    assert not has_inf, f"{s} is inf, {t}"
    has_nan = torch.isnan(t).any()
    assert not has_nan, f"{s} is nan, {t}"
================================================
FILE: smart/utils/weight_init.py
================================================
import torch.nn as nn
def _init_multihead_attention(m: nn.MultiheadAttention) -> None:
    """Xavier-style initialization for every projection of a MultiheadAttention."""
    if m.in_proj_weight is not None:
        # Fused q/k/v projection: uniform bound derived from embed_dim only.
        fan_in = m.embed_dim
        fan_out = m.embed_dim
        bound = (6.0 / (fan_in + fan_out)) ** 0.5
        nn.init.uniform_(m.in_proj_weight, -bound, bound)
    else:
        nn.init.xavier_uniform_(m.q_proj_weight)
        nn.init.xavier_uniform_(m.k_proj_weight)
        nn.init.xavier_uniform_(m.v_proj_weight)
    if m.in_proj_bias is not None:
        nn.init.zeros_(m.in_proj_bias)
    nn.init.xavier_uniform_(m.out_proj.weight)
    if m.out_proj.bias is not None:
        nn.init.zeros_(m.out_proj.bias)
    if m.bias_k is not None:
        nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
    if m.bias_v is not None:
        nn.init.normal_(m.bias_v, mean=0.0, std=0.02)


def _init_rnn(m: nn.Module, num_chunks: int, forget_gate_bias: bool) -> None:
    """Initialize an RNN family module: Xavier for input-hidden weights,
    orthogonal for hidden-hidden weights (per gate chunk), zeros for biases.

    When *forget_gate_bias* is set (LSTM: gate order i, f, g, o), the
    forget-gate slice of bias_hh is set to ones.
    """
    for name, param in m.named_parameters():
        if 'weight_ih' in name:
            for ih in param.chunk(num_chunks, 0):
                nn.init.xavier_uniform_(ih)
        elif 'weight_hh' in name:
            for hh in param.chunk(num_chunks, 0):
                nn.init.orthogonal_(hh)
        elif 'weight_hr' in name:
            # Projection weights (LSTM with proj_size > 0).
            nn.init.xavier_uniform_(param)
        elif 'bias_ih' in name:
            nn.init.zeros_(param)
        elif 'bias_hh' in name:
            nn.init.zeros_(param)
            if forget_gate_bias:
                nn.init.ones_(param.chunk(num_chunks, 0)[1])


def weight_init(m: nn.Module) -> None:
    """Initialize module *m* in place; intended for ``model.apply(weight_init)``.

    Linear: Xavier-uniform weight, zero bias. Conv1/2/3d: uniform with a
    Xavier-style bound over channels-per-group (NOTE(review): unlike
    ``nn.init.xavier_uniform_`` this bound ignores the kernel size —
    presumably intentional, confirm). Embedding: N(0, 0.02). Norm layers:
    weight 1, bias 0. Attention and RNN modules delegate to the helpers above.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        fan_in = m.in_channels / m.groups
        fan_out = m.out_channels / m.groups
        bound = (6.0 / (fan_in + fan_out)) ** 0.5
        nn.init.uniform_(m.weight, -bound, bound)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.MultiheadAttention):
        _init_multihead_attention(m)
    elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
        _init_rnn(m, num_chunks=4, forget_gate_bias=True)
    elif isinstance(m, (nn.GRU, nn.GRUCell)):
        _init_rnn(m, num_chunks=3, forget_gate_bias=False)
================================================
FILE: train.py
================================================
from argparse import ArgumentParser
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from smart.utils.config import load_config_act
from smart.datamodules import MultiDataModule
from smart.model import SMART
from smart.utils.log import Logging
if __name__ == '__main__':
    # Parse CLI options for config and checkpoint locations.
    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train/train_scalable.yaml')
    parser.add_argument('--pretrain_ckpt', type=str, default="")
    parser.add_argument('--ckpt_path', type=str, default="")
    parser.add_argument('--save_ckpt_path', type=str, default="")
    args = parser.parse_args()

    config = load_config_act(args.config)

    # Resolve the predictor class named by the config.
    predictor_classes = {"smart": SMART}
    Predictor = predictor_classes[config.Model.predictor]

    data_cfg = config.Dataset
    datamodule = MultiDataModule(**vars(data_cfg))

    # Build the model, optionally warm-starting from a pretrained checkpoint.
    model = Predictor(config.Model)
    if args.pretrain_ckpt != "":
        logger = Logging().log(level='DEBUG')
        model.load_params_from_file(filename=args.pretrain_ckpt,
                                    logger=logger)

    trainer_cfg = config.Trainer
    checkpoint_cb = ModelCheckpoint(dirpath=args.save_ckpt_path,
                                    filename="{epoch:02d}",
                                    monitor='val_cls_acc',
                                    every_n_epochs=1,
                                    save_top_k=5,
                                    mode='max')
    lr_cb = LearningRateMonitor(logging_interval='epoch')
    ddp = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
    trainer = pl.Trainer(accelerator=trainer_cfg.accelerator,
                         devices=trainer_cfg.devices,
                         strategy=ddp,
                         accumulate_grad_batches=trainer_cfg.accumulate_grad_batches,
                         num_nodes=trainer_cfg.num_nodes,
                         callbacks=[checkpoint_cb, lr_cb],
                         max_epochs=trainer_cfg.max_epochs,
                         num_sanity_val_steps=0,
                         gradient_clip_val=0.5)

    # Resume from a trainer checkpoint only when one was supplied.
    fit_kwargs = {}
    if args.ckpt_path != "":
        fit_kwargs['ckpt_path'] = args.ckpt_path
    trainer.fit(model, datamodule, **fit_kwargs)
================================================
FILE: val.py
================================================
from argparse import ArgumentParser
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader
from smart.datasets.scalable_dataset import MultiDataset
from smart.model import SMART
from smart.transforms import WaymoTargetBuilder
from smart.utils.config import load_config_act
from smart.utils.log import Logging
if __name__ == '__main__':
    # Fix the seed so validation runs are reproducible across workers.
    pl.seed_everything(2, workers=True)

    parser = ArgumentParser()
    parser.add_argument('--config', type=str, default="configs/validation/validation_scalable.yaml")
    parser.add_argument('--pretrain_ckpt', type=str, default="")
    parser.add_argument('--ckpt_path', type=str, default="")
    parser.add_argument('--save_ckpt_path', type=str, default="")
    args = parser.parse_args()

    config = load_config_act(args.config)
    data_cfg = config.Dataset

    # Resolve the dataset class named by the config and build the val split.
    dataset_classes = {"scalable": MultiDataset}
    dataset_cls = dataset_classes[data_cfg.dataset]
    target_builder = WaymoTargetBuilder(config.Model.num_historical_steps,
                                        config.Model.decoder.num_future_steps)
    val_dataset = dataset_cls(root=data_cfg.root, split='val',
                              raw_dir=data_cfg.val_raw_dir,
                              processed_dir=data_cfg.val_processed_dir,
                              transform=target_builder)
    dataloader = DataLoader(val_dataset,
                            batch_size=data_cfg.batch_size,
                            shuffle=False,
                            num_workers=data_cfg.num_workers,
                            pin_memory=data_cfg.pin_memory,
                            persistent_workers=data_cfg.num_workers > 0)

    # Build the model, optionally loading pretrained weights.
    Predictor = SMART
    model = Predictor(config.Model)
    if args.pretrain_ckpt != "":
        logger = Logging().log(level='DEBUG')
        model.load_params_from_file(filename=args.pretrain_ckpt,
                                    logger=logger)

    trainer_cfg = config.Trainer
    trainer = pl.Trainer(accelerator=trainer_cfg.accelerator,
                         devices=trainer_cfg.devices,
                         strategy='ddp', num_sanity_val_steps=0)
    trainer.validate(model, dataloader)