main 4024d12ca708 cached
7 files
68.4 KB
17.5k tokens
21 symbols
1 requests
Download .txt
Repository: kbrodt/waymo-motion-prediction-2021
Branch: main
Commit: 4024d12ca708
Files: 7
Total size: 68.4 KB

Directory structure:
gitextract_9tsvvauy/

├── README.md
├── prerender.py
├── requirements.txt
├── submission_proto/
│   └── motion_submission_pb2.py
├── submit.py
├── train.py
└── visualize.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Waymo motion prediction challenge 2021: 3rd place solution 

![header](docs/header.png)

- 📜[**Paper**](https://arxiv.org/abs/2206.02163)  
- 🗨️[Presentation](./docs/waymo_motion_prediction_2021_3rd_place_solution_presentation.pdf)  
- 🎉[Announcement](https://youtu.be/eOL_rCK59ZI?t=6485)    
- 🛆[Motion Prediction Channel Website](https://waymo.com/open/challenges/2021/motion-prediction/)  
- 🛆[CVPR2021 workshop](http://cvpr2021.wad.vision/)  

-  **UPDATE❗** Related repo with [3rd place solution code](https://github.com/stepankonev/waymo-motion-prediction-challenge-2022-multipath-plus-plus) for Waymo Motion Prediction Challenge 2022 
-  **UPDATE❗** Related repo with [refactored code for MotionCNN](https://github.com/stepankonev/MotionCNN-Waymo-Open-Motion-Dataset)

## Team behind this solution:
1. Artsiom Sanakoyeu [[Homepage](https://gdude.de)] [[Twitter](https://twitter.com/artsiom_s)] [[Telegram Channel](https://t.me/gradientdude)] [[LinkedIn](https://www.linkedin.com/in/sanakoev)]
2. Stepan Konev [[LinkedIn]](https://www.linkedin.com/in/stepan-konev/)
3. Kirill Brodt [[GitHub]](https://github.com/kbrodt)

## Dataset

Download
[datasets](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_0_0)
`uncompressed/tf_example/{training,validation,testing}`

## Prerender

Change paths to input dataset and output folders

```bash
python prerender.py \
    --data /home/data/waymo/training \
    --out ./train
    
python prerender.py \
    --data /home/data/waymo/validation \
    --out ./dev \
    --use-vectorize \
    --n-shards 1
    
python prerender.py \
    --data /home/data/waymo/testing \
    --out ./test \
    --use-vectorize \
    --n-shards 1
```

## Training

```bash
MODEL_NAME=xception71
python train.py \
    --train-data ./train \
    --dev-data ./dev \
    --save ./${MODEL_NAME} \
    --model ${MODEL_NAME} \
    --img-res 224 \
    --in-channels 25 \
    --time-limit 80 \
    --n-traj 6 \
    --lr 0.001 \
    --batch-size 48 \
    --n-epochs 120
```

## Submit

```bash
python submit.py \
    --test-data ./test/ \
    --model-path ${MODEL_PATH_TO_JIT} \
    --save ${SAVE}
```


## Visualize predictions

```bash
python visualize.py \
    --model ${MODEL_PATH_TO_JIT} \
    --data ${DATA_PATH} \
    --save ./viz
```

## Citation
If you find our work useful, please cite it as:
```
@misc{konev2022motioncnn,
      title={MotionCNN: A Strong Baseline for Motion Prediction in Autonomous Driving}, 
      author={Stepan Konev and Kirill Brodt and Artsiom Sanakoyeu},
      year={2022},
      eprint={2206.02163},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Related repos

* [Kaggle Lyft motion prediciton 3rd place solution](https://gdude.de/blog/2021-02-05/Kaggle-Lyft-solution)


================================================
FILE: prerender.py
================================================
import argparse
import multiprocessing
import os

import cv2
import numpy as np
import tensorflow as tf
from tqdm import tqdm

roadgraph_features = {
    "roadgraph_samples/dir": tf.io.FixedLenFeature(
        [20000, 3], tf.float32, default_value=None
    ),
    "roadgraph_samples/id": tf.io.FixedLenFeature(
        [20000, 1], tf.int64, default_value=None
    ),
    "roadgraph_samples/type": tf.io.FixedLenFeature(
        [20000, 1], tf.int64, default_value=None
    ),
    "roadgraph_samples/valid": tf.io.FixedLenFeature(
        [20000, 1], tf.int64, default_value=None
    ),
    "roadgraph_samples/xyz": tf.io.FixedLenFeature(
        [20000, 3], tf.float32, default_value=None
    ),
}

# Features of other agents.
state_features = {
    "state/id": tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    "state/type": tf.io.FixedLenFeature([128], tf.float32, default_value=None),
    "state/is_sdc": tf.io.FixedLenFeature([128], tf.int64, default_value=None),
    "state/tracks_to_predict": tf.io.FixedLenFeature(
        [128], tf.int64, default_value=None
    ),
    "state/current/bbox_yaw": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/height": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/length": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/timestamp_micros": tf.io.FixedLenFeature(
        [128, 1], tf.int64, default_value=None
    ),
    "state/current/valid": tf.io.FixedLenFeature(
        [128, 1], tf.int64, default_value=None
    ),
    "state/current/vel_yaw": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/velocity_x": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/velocity_y": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/speed": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/width": tf.io.FixedLenFeature(
        [128, 1], tf.float32, default_value=None
    ),
    "state/current/x": tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    "state/current/y": tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    "state/current/z": tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
    "state/future/bbox_yaw": tf.io.FixedLenFeature(
        [128, 80], tf.float32, default_value=None
    ),
    "state/future/height": tf.io.FixedLenFeature(
        [128, 80], tf.float32, default_value=None
    ),
    "state/future/length": tf.io.FixedLenFeature(
        [128, 80], tf.float32, default_value=None
    ),
    "state/future/timestamp_micros": tf.io.FixedLenFeature(
        [128, 80], tf.int64, default_value=None
    ),
    "state/future/valid": tf.io.FixedLenFeature(
        [128, 80], tf.int64, default_value=None
    ),
    "state/future/vel_yaw": tf.io.FixedLenFeature(
        [128, 80], tf.float32, default_value=None
    ),
    "state/future/velocity_x": tf.io.FixedLenFeature(
        [128, 80], tf.float32, default_value=None
    ),
    "state/future/velocity_y": tf.io.FixedLenFeature(
        [128, 80], tf.float32, default_value=None
    ),
    "state/future/width": tf.io.FixedLenFeature(
        [128, 80], tf.float32, default_value=None
    ),
    "state/future/x": tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    "state/future/y": tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    "state/future/z": tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
    "state/past/bbox_yaw": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/height": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/length": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/timestamp_micros": tf.io.FixedLenFeature(
        [128, 10], tf.int64, default_value=None
    ),
    "state/past/valid": tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),
    "state/past/vel_yaw": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/velocity_x": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/velocity_y": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/speed": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/width": tf.io.FixedLenFeature(
        [128, 10], tf.float32, default_value=None
    ),
    "state/past/x": tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    "state/past/y": tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    "state/past/z": tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
    "scenario/id": tf.io.FixedLenFeature([1], tf.string, default_value=None),
}

traffic_light_features = {
    "traffic_light_state/current/state": tf.io.FixedLenFeature(
        [1, 16], tf.int64, default_value=None
    ),
    "traffic_light_state/current/valid": tf.io.FixedLenFeature(
        [1, 16], tf.int64, default_value=None
    ),
    "traffic_light_state/current/id": tf.io.FixedLenFeature(
        [1, 16], tf.int64, default_value=None
    ),
    "traffic_light_state/current/x": tf.io.FixedLenFeature(
        [1, 16], tf.float32, default_value=None
    ),
    "traffic_light_state/current/y": tf.io.FixedLenFeature(
        [1, 16], tf.float32, default_value=None
    ),
    "traffic_light_state/current/z": tf.io.FixedLenFeature(
        [1, 16], tf.float32, default_value=None
    ),
    "traffic_light_state/past/state": tf.io.FixedLenFeature(
        [10, 16], tf.int64, default_value=None
    ),
    "traffic_light_state/past/valid": tf.io.FixedLenFeature(
        [10, 16], tf.int64, default_value=None
    ),
    # "traffic_light_state/past/id":
    # tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
    "traffic_light_state/past/x": tf.io.FixedLenFeature(
        [10, 16], tf.float32, default_value=None
    ),
    "traffic_light_state/past/y": tf.io.FixedLenFeature(
        [10, 16], tf.float32, default_value=None
    ),
    "traffic_light_state/past/z": tf.io.FixedLenFeature(
        [10, 16], tf.float32, default_value=None
    ),
}

features_description = {}
features_description.update(roadgraph_features)
features_description.update(state_features)
features_description.update(traffic_light_features)
MAX_PIXEL_VALUE = 255
N_ROADS = 21
road_colors = [int(x) for x in np.linspace(1, MAX_PIXEL_VALUE, N_ROADS).astype("uint8")]
idx2type = ["unset", "vehicle", "pedestrian", "cyclist", "other"]


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, required=True, help="Path to raw data")
    parser.add_argument("--out", type=str, required=True, help="Path to save data")
    parser.add_argument(
        "--no-valid", action="store_true", help="Use data with flag `valid = 0`"
    )
    parser.add_argument(
        "--use-vectorize", action="store_true", help="Generate vector data"
    )
    parser.add_argument(
        "--n-jobs", type=int, default=20, required=False, help="Number of threads"
    )
    parser.add_argument(
        "--n-shards",
        type=int,
        default=8,
        required=False,
        help="Use `1/n_shards` of full dataset",
    )
    parser.add_argument(
        "--each",
        type=int,
        default=0,
        required=False,
        help="Take `each` sample in shard",
    )

    args = parser.parse_args()

    return args


def rasterize(
    tracks_to_predict,
    past_x,
    past_y,
    current_x,
    current_y,
    current_yaw,
    past_yaw,
    past_valid,
    current_valid,
    agent_type,
    roadlines_coords,
    roadlines_types,
    roadlines_valid,
    roadlines_ids,
    widths,
    lengths,
    agents_ids,
    tl_states,
    tl_ids,
    tl_valids,
    future_x,
    future_y,
    future_valid,
    scenario_id,
    validate,
    crop_size=512,
    raster_size=224,
    shift=2 ** 9,
    magic_const=3,
    n_channels=11,
):
    GRES = []
    displacement = np.array([[raster_size // 4, raster_size // 2]]) * shift
    tl_dict = {"green": set(), "yellow": set(), "red": set()}

    # Unknown = 0, Arrow_Stop = 1, Arrow_Caution = 2, Arrow_Go = 3, Stop = 4,
    # Caution = 5, Go = 6, Flashing_Stop = 7, Flashing_Caution = 8
    for tl_state, tl_id, tl_valid in zip(
        tl_states.flatten(), tl_ids.flatten(), tl_valids.flatten()
    ):
        if tl_valid == 0:
            continue
        if tl_state in [1, 4, 7]:
            tl_dict["red"].add(tl_id)
        if tl_state in [2, 5, 8]:
            tl_dict["yellow"].add(tl_id)
        if tl_state in [3, 6]:
            tl_dict["green"].add(tl_id)

    XY = np.concatenate(
        (
            np.expand_dims(np.concatenate((past_x, current_x), axis=1), axis=-1),
            np.expand_dims(np.concatenate((past_y, current_y), axis=1), axis=-1),
        ),
        axis=-1,
    )

    GT_XY = np.concatenate(
        (np.expand_dims(future_x, axis=-1), np.expand_dims(future_y, axis=-1)), axis=-1
    )

    YAWS = np.concatenate((past_yaw, current_yaw), axis=1)

    agents_valid = np.concatenate((past_valid, current_valid), axis=1)

    roadlines_valid = roadlines_valid.reshape(-1)
    roadlines_coords = (
        roadlines_coords[:, :2][roadlines_valid > 0]
        * shift
        * magic_const
        * raster_size
        / crop_size
    )
    roadlines_types = roadlines_types[roadlines_valid > 0]
    roadlines_ids = roadlines_ids.reshape(-1)[roadlines_valid > 0]

    for _, (
        xy,
        current_val,
        val,
        _,
        yaw,
        agent_id,
        gt_xy,
        future_val,
        predict,
    ) in enumerate(
        zip(
            XY,
            current_valid,
            agents_valid,
            agent_type,
            current_yaw.flatten(),
            agents_ids,
            GT_XY,
            future_valid,
            tracks_to_predict.flatten(),
        )
    ):
        if (not validate and future_val.sum() == 0) or (validate and predict == 0):
            continue
        if current_val == 0:
            continue

        RES_ROADMAP = (
            np.ones((raster_size, raster_size, 3), dtype=np.uint8) * MAX_PIXEL_VALUE
        )
        RES_EGO = [
            np.zeros((raster_size, raster_size, 1), dtype=np.uint8)
            for _ in range(n_channels)
        ]
        RES_OTHER = [
            np.zeros((raster_size, raster_size, 1), dtype=np.uint8)
            for _ in range(n_channels)
        ]

        xy_val = xy[val > 0]
        if len(xy_val) == 0:
            continue

        unscaled_center_xy = xy_val[-1].reshape(1, -1)
        center_xy = unscaled_center_xy * shift * magic_const * raster_size / crop_size
        rot_matrix = np.array(
            [
                [np.cos(yaw), -np.sin(yaw)],
                [np.sin(yaw), np.cos(yaw)],
            ]
        )

        centered_roadlines = (roadlines_coords - center_xy) @ rot_matrix + displacement
        centered_others = (
            XY.reshape(-1, 2) * shift * magic_const * raster_size / crop_size
            - center_xy
        ) @ rot_matrix + displacement
        centered_others = centered_others.reshape(128, n_channels, 2)
        centered_gt = (gt_xy - unscaled_center_xy) @ rot_matrix

        unique_road_ids = np.unique(roadlines_ids)
        for road_id in unique_road_ids:
            if road_id >= 0:
                roadline = centered_roadlines[roadlines_ids == road_id]
                road_type = roadlines_types[roadlines_ids == road_id].flatten()[0]

                road_color = road_colors[road_type]
                for c, rgb in zip(
                    ["green", "yellow", "red"],
                    [
                        (0, MAX_PIXEL_VALUE, 0),
                        (MAX_PIXEL_VALUE, 211, 0),
                        (MAX_PIXEL_VALUE, 0, 0),
                    ],
                ):
                    if road_id in tl_dict[c]:
                        road_color = rgb

                RES_ROADMAP = cv2.polylines(
                    RES_ROADMAP,
                    [roadline.astype(int)],
                    False,
                    road_color,
                    shift=9,
                )

        unique_agent_ids = np.unique(agents_ids)

        is_ego = False
        self_type = 0
        _tmp = 0
        for other_agent_id in unique_agent_ids:
            other_agent_id = int(other_agent_id)
            if other_agent_id < 1:
                continue
            if other_agent_id == agent_id:
                is_ego = True
                self_type = agent_type[agents_ids == other_agent_id]
            else:
                is_ego = False

            _tmp += 1
            agent_lane = centered_others[agents_ids == other_agent_id][0]
            agent_valid = agents_valid[agents_ids == other_agent_id]
            agent_yaw = YAWS[agents_ids == other_agent_id]

            agent_l = lengths[agents_ids == other_agent_id]
            agent_w = widths[agents_ids == other_agent_id]

            for timestamp, (coord, valid_coordinate, past_yaw,) in enumerate(
                zip(
                    agent_lane,
                    agent_valid.flatten(),
                    agent_yaw.flatten(),
                )
            ):
                if valid_coordinate == 0:
                    continue
                box_points = (
                    np.array(
                        [
                            -agent_l,
                            -agent_w,
                            agent_l,
                            -agent_w,
                            agent_l,
                            agent_w,
                            -agent_l,
                            agent_w,
                        ]
                    )
                    .reshape(4, 2)
                    .astype(np.float32)
                    * shift
                    * magic_const
                    / 2
                    * raster_size
                    / crop_size
                )

                box_points = (
                    box_points
                    @ np.array(
                        (
                            (np.cos(yaw - past_yaw), -np.sin(yaw - past_yaw)),
                            (np.sin(yaw - past_yaw), np.cos(yaw - past_yaw)),
                        )
                    ).reshape(2, 2)
                )

                _coord = np.array([coord])

                box_points = box_points + _coord
                box_points = box_points.reshape(1, -1, 2).astype(np.int32)

                if is_ego:
                    cv2.fillPoly(
                        RES_EGO[timestamp],
                        box_points,
                        color=MAX_PIXEL_VALUE,
                        shift=9,
                    )
                else:
                    cv2.fillPoly(
                        RES_OTHER[timestamp],
                        box_points,
                        color=MAX_PIXEL_VALUE,
                        shift=9,
                    )

        raster = np.concatenate([RES_ROADMAP] + RES_EGO + RES_OTHER, axis=2)

        raster_dict = {
            "object_id": agent_id,
            "raster": raster,
            "yaw": yaw,
            "shift": unscaled_center_xy,
            "_gt_marginal": gt_xy,
            "gt_marginal": centered_gt,
            "future_val_marginal": future_val,
            "gt_joint": GT_XY[tracks_to_predict.flatten() > 0],
            "future_val_joint": future_valid[tracks_to_predict.flatten() > 0],
            "scenario_id": scenario_id,
            "self_type": self_type,
        }

        GRES.append(raster_dict)

    return GRES


F2I = {
    "x": 0,
    "y": 1,
    "s": 2,
    "vel_yaw": 3,
    "bbox_yaw": 4,
    "l": 5,
    "w": 6,
    "agent_type_range": [7, 12],
    "lane_range": [13, 33],
    "lt_range": [34, 43],
    "global_idx": 44,
}


def ohe(N, n, zero):
    n = int(n)
    N = int(N)
    M = np.eye(N)
    diff = 0
    if zero:
        M = np.concatenate((np.zeros((1, N)), M), axis=0)
        diff = 1
    return M[n + diff]


def make_2d(arraylist):
    n = len(arraylist)
    k = arraylist[0].shape[0]
    a2d = np.zeros((n, k))
    for i in range(n):
        a2d[i] = arraylist[i]
    return a2d


def vectorize(
    past_x,
    current_x,
    past_y,
    current_y,
    past_valid,
    current_valid,
    past_speed,
    current_speed,
    past_velocity_yaw,
    current_velocity_yaw,
    past_bbox_yaw,
    current_bbox_yaw,
    Agent_id,
    Agent_type,
    Roadline_id,
    Roadline_type,
    Roadline_valid,
    Roadline_xy,
    Tl_rl_id,
    Tl_state,
    Tl_valid,
    W,
    L,
    tracks_to_predict,
    future_valid,
    validate,
    n_channels=11,
):

    XY = np.concatenate(
        (
            np.expand_dims(np.concatenate((past_x, current_x), axis=1), axis=-1),
            np.expand_dims(np.concatenate((past_y, current_y), axis=1), axis=-1),
        ),
        axis=-1,
    )

    Roadline_valid = Roadline_valid.flatten()
    RoadXY = Roadline_xy[:, :2][Roadline_valid > 0]
    Roadline_type = Roadline_type[Roadline_valid > 0].flatten()
    Roadline_id = Roadline_id[Roadline_valid > 0].flatten()

    tl_state = [[-1] for _ in range(9)]

    for lane_id, state, valid in zip(
        Tl_rl_id.flatten(), Tl_state.flatten(), Tl_valid.flatten()
    ):
        if valid == 0:
            continue
        tl_state[int(state)].append(lane_id)

    VALID = np.concatenate((past_valid, current_valid), axis=1)

    Speed = np.concatenate((past_speed, current_speed), axis=1)
    Vyaw = np.concatenate((past_velocity_yaw, current_velocity_yaw), axis=1)
    Bbox_yaw = np.concatenate((past_bbox_yaw, current_bbox_yaw), axis=1)

    GRES = []

    ROADLINES_STATE = []

    GLOBAL_IDX = -1

    unique_road_ids = np.unique(Roadline_id)
    for road_id in unique_road_ids:

        GLOBAL_IDX += 1

        roadline_coords = RoadXY[Roadline_id == road_id]
        roadline_type = Roadline_type[Roadline_id == road_id][0]

        for i, (x, y) in enumerate(roadline_coords):
            if i > 0 and i < len(roadline_coords) - 1 and i % 3 > 0:
                continue
            tmp = np.zeros(48)
            tmp[0] = x
            tmp[1] = y

            tmp[13:33] = ohe(20, roadline_type, True)

            tmp[44] = GLOBAL_IDX

            ROADLINES_STATE.append(tmp)

    ROADLINES_STATE = make_2d(ROADLINES_STATE)

    for (
        agent_id,
        xy,
        current_val,
        valid,
        _,
        bbox_yaw,
        _,
        _,
        _,
        future_val,
        predict,
    ) in zip(
        Agent_id,
        XY,
        current_valid,
        VALID,
        Speed,
        Bbox_yaw,
        Vyaw,
        W,
        L,
        future_valid,
        tracks_to_predict.flatten(),
    ):

        if (not validate and future_val.sum() == 0) or (validate and predict == 0):
            continue
        if current_val == 0:
            continue

        GLOBAL_IDX = -1
        RES = []

        xy_val = xy[valid > 0]
        if len(xy_val) == 0:
            continue

        centered_xy = xy_val[-1].copy().reshape(-1, 2)

        ANGLE = bbox_yaw[-1]

        rot_matrix = np.array(
            [
                [np.cos(ANGLE), -np.sin(ANGLE)],
                [np.sin(ANGLE), np.cos(ANGLE)],
            ]
        ).reshape(2, 2)

        local_roadlines_state = ROADLINES_STATE.copy()

        local_roadlines_state[:, :2] = (
            local_roadlines_state[:, :2] - centered_xy
        ) @ rot_matrix.astype(np.float64)

        local_XY = ((XY - centered_xy).reshape(-1, 2) @ rot_matrix).reshape(
            128, n_channels, 2
        )

        for (
            other_agent_id,
            other_agent_type,
            other_xy,
            other_valids,
            other_speeds,
            other_bbox_yaws,
            other_v_yaws,
            other_w,
            other_l,
            other_predict,
        ) in zip(
            Agent_id,
            Agent_type,
            local_XY,
            VALID,
            Speed,
            Bbox_yaw,
            Vyaw,
            W.flatten(),
            L.flatten(),
            tracks_to_predict.flatten(),
        ):
            if other_valids.sum() == 0:
                continue

            GLOBAL_IDX += 1
            for timestamp, (
                (x, y),
                v,
                other_speed,
                other_v_yaw,
                other_bbox_yaw,
            ) in enumerate(
                zip(other_xy, other_valids, other_speeds, other_v_yaws, other_bbox_yaws)
            ):
                if v == 0:
                    continue
                tmp = np.zeros(48)
                tmp[0] = x
                tmp[1] = y
                tmp[2] = other_speed
                tmp[3] = other_v_yaw - ANGLE
                tmp[4] = other_bbox_yaw - ANGLE
                tmp[5] = float(other_l)
                tmp[6] = float(other_w)

                tmp[7:12] = ohe(5, other_agent_type, True)

                tmp[43] = timestamp

                tmp[44] = GLOBAL_IDX
                tmp[45] = 1 if other_agent_id == agent_id else 0
                tmp[46] = other_predict
                tmp[47] = other_agent_id

                RES.append(tmp)
        local_roadlines_state[:, 44] = local_roadlines_state[:, 44] + GLOBAL_IDX + 1
        RES = np.concatenate((make_2d(RES), local_roadlines_state), axis=0)
        GRES.append(RES)

    return GRES


def merge(
    data, proc_id, validate, out_dir, use_vectorize=False, max_rand_int=10000000000
):
    parsed = tf.io.parse_single_example(data, features_description)
    raster_data = rasterize(
        parsed["state/tracks_to_predict"].numpy(),
        parsed["state/past/x"].numpy(),
        parsed["state/past/y"].numpy(),
        parsed["state/current/x"].numpy(),
        parsed["state/current/y"].numpy(),
        parsed["state/current/bbox_yaw"].numpy(),
        parsed["state/past/bbox_yaw"].numpy(),
        parsed["state/past/valid"].numpy(),
        parsed["state/current/valid"].numpy(),
        parsed["state/type"].numpy(),
        parsed["roadgraph_samples/xyz"].numpy(),
        parsed["roadgraph_samples/type"].numpy(),
        parsed["roadgraph_samples/valid"].numpy(),
        parsed["roadgraph_samples/id"].numpy(),
        parsed["state/current/width"].numpy(),
        parsed["state/current/length"].numpy(),
        parsed["state/id"].numpy(),
        parsed["traffic_light_state/current/state"].numpy(),
        parsed["traffic_light_state/current/id"].numpy(),
        parsed["traffic_light_state/current/valid"].numpy(),
        parsed["state/future/x"].numpy(),
        parsed["state/future/y"].numpy(),
        parsed["state/future/valid"].numpy(),
        parsed["scenario/id"].numpy()[0].decode("utf-8"),
        validate=validate,
    )

    if use_vectorize:
        vector_data = vectorize(
            parsed["state/past/x"].numpy(),
            parsed["state/current/x"].numpy(),
            parsed["state/past/y"].numpy(),
            parsed["state/current/y"].numpy(),
            parsed["state/past/valid"].numpy(),
            parsed["state/current/valid"].numpy(),
            parsed["state/past/speed"].numpy(),
            parsed["state/current/speed"].numpy(),
            parsed["state/past/vel_yaw"].numpy(),
            parsed["state/current/vel_yaw"].numpy(),
            parsed["state/past/bbox_yaw"].numpy(),
            parsed["state/current/bbox_yaw"].numpy(),
            parsed["state/id"].numpy(),
            parsed["state/type"].numpy(),
            parsed["roadgraph_samples/id"].numpy(),
            parsed["roadgraph_samples/type"].numpy(),
            parsed["roadgraph_samples/valid"].numpy(),
            parsed["roadgraph_samples/xyz"].numpy(),
            parsed["traffic_light_state/current/id"].numpy(),
            parsed["traffic_light_state/current/state"].numpy(),
            parsed["traffic_light_state/current/valid"].numpy(),
            parsed["state/current/width"].numpy(),
            parsed["state/current/length"].numpy(),
            parsed["state/tracks_to_predict"].numpy(),
            parsed["state/future/valid"].numpy(),
            validate=validate,
        )

    for i in range(len(raster_data)):
        if use_vectorize:
            raster_data[i]["vector_data"] = vector_data[i].astype(np.float16)

        r = np.random.randint(max_rand_int)
        filename = f"{idx2type[int(raster_data[i]['self_type'])]}_{proc_id}_{str(i).zfill(5)}_{r}.npz"
        np.savez_compressed(os.path.join(out_dir, filename), **raster_data[i])


def main():
    args = parse_arguments()
    print(args)

    if not os.path.exists(args.out):
        os.mkdir(args.out)

    files = os.listdir(args.data)
    dataset = tf.data.TFRecordDataset(
        [os.path.join(args.data, f) for f in files], num_parallel_reads=1
    )
    if args.n_shards > 1:
        dataset = dataset.shard(args.n_shards, args.each)

    p = multiprocessing.Pool(args.n_jobs)
    proc_id = 0
    res = []
    for data in tqdm(dataset.as_numpy_iterator()):
        proc_id += 1
        res.append(
            p.apply_async(
                merge,
                kwds=dict(
                    data=data,
                    proc_id=proc_id,
                    validate=not args.no_valid,
                    out_dir=args.out,
                    use_vectorize=args.use_vectorize,
                ),
            )
        )

    for r in tqdm(res):
        r.get()


if __name__ == "__main__":
    main()


================================================
FILE: requirements.txt
================================================
numpy
opencv-python
tensorflow
timm
torch
tqdm


================================================
FILE: submission_proto/motion_submission_pb2.py
================================================
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler.  DO NOT EDIT!
# source: motion_submission.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor.FileDescriptor(
  name='motion_submission.proto',
  package='waymo.open_dataset',
  syntax='proto2',
  serialized_options=None,
  create_key=_descriptor._internal_create_key,
  serialized_pb=b'\n\x17motion_submission.proto\x12\x12waymo.open_dataset\"8\n\nTrajectory\x12\x14\n\x08\x63\x65nter_x\x18\x02 \x03(\x02\x42\x02\x10\x01\x12\x14\n\x08\x63\x65nter_y\x18\x03 \x03(\x02\x42\x02\x10\x01\"Z\n\x10ScoredTrajectory\x12\x32\n\ntrajectory\x18\x01 \x01(\x0b\x32\x1e.waymo.open_dataset.Trajectory\x12\x12\n\nconfidence\x18\x02 \x01(\x02\"g\n\x16SingleObjectPrediction\x12\x11\n\tobject_id\x18\x01 \x01(\x05\x12:\n\x0ctrajectories\x18\x02 \x03(\x0b\x32$.waymo.open_dataset.ScoredTrajectory\"P\n\rPredictionSet\x12?\n\x0bpredictions\x18\x01 \x03(\x0b\x32*.waymo.open_dataset.SingleObjectPrediction\"Y\n\x10ObjectTrajectory\x12\x11\n\tobject_id\x18\x01 \x01(\x05\x12\x32\n\ntrajectory\x18\x02 \x01(\x0b\x32\x1e.waymo.open_dataset.Trajectory\"g\n\x15ScoredJointTrajectory\x12:\n\x0ctrajectories\x18\x02 \x03(\x0b\x32$.waymo.open_dataset.ObjectTrajectory\x12\x12\n\nconfidence\x18\x03 \x01(\x02\"X\n\x0fJointPrediction\x12\x45\n\x12joint_trajectories\x18\x01 \x03(\x0b\x32).waymo.open_dataset.ScoredJointTrajectory\"\xc7\x01\n\x1c\x43hallengeScenarioPredictions\x12\x13\n\x0bscenario_id\x18\x01 \x01(\t\x12?\n\x12single_predictions\x18\x02 \x01(\x0b\x32!.waymo.open_dataset.PredictionSetH\x00\x12?\n\x10joint_prediction\x18\x03 \x01(\x0b\x32#.waymo.open_dataset.JointPredictionH\x00\x42\x10\n\x0eprediction_set\"\x96\x03\n\x19MotionChallengeSubmission\x12\x14\n\x0c\x61\x63\x63ount_name\x18\x03 \x01(\t\x12\x1a\n\x12unique_method_name\x18\x04 \x01(\t\x12\x0f\n\x07\x61uthors\x18\x05 \x03(\t\x12\x13\n\x0b\x61\x66\x66iliation\x18\x06 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x07 \x01(\t\x12\x13\n\x0bmethod_link\x18\x08 \x01(\t\x12U\n\x0fsubmission_type\x18\x02 \x01(\x0e\x32<.waymo.open_dataset.MotionChallengeSubmission.SubmissionType\x12N\n\x14scenario_predictions\x18\x01 \x03(\x0b\x32\x30.waymo.open_dataset.ChallengeScenarioPredictions\"P\n\x0eSubmissionType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x15\n\x11MOTION_PREDICTION\x10\x01\x12\x1a\n\x16INTERACTION_PREDICTION\x10\x02'
)



_MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE = _descriptor.EnumDescriptor(
  name='SubmissionType',
  full_name='waymo.open_dataset.MotionChallengeSubmission.SubmissionType',
  filename=None,
  file=DESCRIPTOR,
  create_key=_descriptor._internal_create_key,
  values=[
    _descriptor.EnumValueDescriptor(
      name='UNKNOWN', index=0, number=0,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
    _descriptor.EnumValueDescriptor(
      name='MOTION_PREDICTION', index=1, number=1,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
    _descriptor.EnumValueDescriptor(
      name='INTERACTION_PREDICTION', index=2, number=2,
      serialized_options=None,
      type=None,
      create_key=_descriptor._internal_create_key),
  ],
  containing_type=None,
  serialized_options=None,
  serialized_start=1199,
  serialized_end=1279,
)
_sym_db.RegisterEnumDescriptor(_MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE)


_TRAJECTORY = _descriptor.Descriptor(
  name='Trajectory',
  full_name='waymo.open_dataset.Trajectory',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='center_x', full_name='waymo.open_dataset.Trajectory.center_x', index=0,
      number=2, type=2, cpp_type=6, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=b'\020\001', file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='center_y', full_name='waymo.open_dataset.Trajectory.center_y', index=1,
      number=3, type=2, cpp_type=6, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=b'\020\001', file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=47,
  serialized_end=103,
)


_SCOREDTRAJECTORY = _descriptor.Descriptor(
  name='ScoredTrajectory',
  full_name='waymo.open_dataset.ScoredTrajectory',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='trajectory', full_name='waymo.open_dataset.ScoredTrajectory.trajectory', index=0,
      number=1, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='confidence', full_name='waymo.open_dataset.ScoredTrajectory.confidence', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=105,
  serialized_end=195,
)


_SINGLEOBJECTPREDICTION = _descriptor.Descriptor(
  name='SingleObjectPrediction',
  full_name='waymo.open_dataset.SingleObjectPrediction',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='object_id', full_name='waymo.open_dataset.SingleObjectPrediction.object_id', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='trajectories', full_name='waymo.open_dataset.SingleObjectPrediction.trajectories', index=1,
      number=2, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=197,
  serialized_end=300,
)


_PREDICTIONSET = _descriptor.Descriptor(
  name='PredictionSet',
  full_name='waymo.open_dataset.PredictionSet',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='predictions', full_name='waymo.open_dataset.PredictionSet.predictions', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=302,
  serialized_end=382,
)


_OBJECTTRAJECTORY = _descriptor.Descriptor(
  name='ObjectTrajectory',
  full_name='waymo.open_dataset.ObjectTrajectory',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='object_id', full_name='waymo.open_dataset.ObjectTrajectory.object_id', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='trajectory', full_name='waymo.open_dataset.ObjectTrajectory.trajectory', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=384,
  serialized_end=473,
)


_SCOREDJOINTTRAJECTORY = _descriptor.Descriptor(
  name='ScoredJointTrajectory',
  full_name='waymo.open_dataset.ScoredJointTrajectory',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='trajectories', full_name='waymo.open_dataset.ScoredJointTrajectory.trajectories', index=0,
      number=2, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='confidence', full_name='waymo.open_dataset.ScoredJointTrajectory.confidence', index=1,
      number=3, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=475,
  serialized_end=578,
)


_JOINTPREDICTION = _descriptor.Descriptor(
  name='JointPrediction',
  full_name='waymo.open_dataset.JointPrediction',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='joint_trajectories', full_name='waymo.open_dataset.JointPrediction.joint_trajectories', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=580,
  serialized_end=668,
)


_CHALLENGESCENARIOPREDICTIONS = _descriptor.Descriptor(
  name='ChallengeScenarioPredictions',
  full_name='waymo.open_dataset.ChallengeScenarioPredictions',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='scenario_id', full_name='waymo.open_dataset.ChallengeScenarioPredictions.scenario_id', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='single_predictions', full_name='waymo.open_dataset.ChallengeScenarioPredictions.single_predictions', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='joint_prediction', full_name='waymo.open_dataset.ChallengeScenarioPredictions.joint_prediction', index=2,
      number=3, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
    _descriptor.OneofDescriptor(
      name='prediction_set', full_name='waymo.open_dataset.ChallengeScenarioPredictions.prediction_set',
      index=0, containing_type=None,
      create_key=_descriptor._internal_create_key,
    fields=[]),
  ],
  serialized_start=671,
  serialized_end=870,
)


_MOTIONCHALLENGESUBMISSION = _descriptor.Descriptor(
  name='MotionChallengeSubmission',
  full_name='waymo.open_dataset.MotionChallengeSubmission',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='account_name', full_name='waymo.open_dataset.MotionChallengeSubmission.account_name', index=0,
      number=3, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='unique_method_name', full_name='waymo.open_dataset.MotionChallengeSubmission.unique_method_name', index=1,
      number=4, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='authors', full_name='waymo.open_dataset.MotionChallengeSubmission.authors', index=2,
      number=5, type=9, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='affiliation', full_name='waymo.open_dataset.MotionChallengeSubmission.affiliation', index=3,
      number=6, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='description', full_name='waymo.open_dataset.MotionChallengeSubmission.description', index=4,
      number=7, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='method_link', full_name='waymo.open_dataset.MotionChallengeSubmission.method_link', index=5,
      number=8, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='submission_type', full_name='waymo.open_dataset.MotionChallengeSubmission.submission_type', index=6,
      number=2, type=14, cpp_type=8, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='scenario_predictions', full_name='waymo.open_dataset.MotionChallengeSubmission.scenario_predictions', index=7,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
    _MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE,
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=873,
  serialized_end=1279,
)

_SCOREDTRAJECTORY.fields_by_name['trajectory'].message_type = _TRAJECTORY
_SINGLEOBJECTPREDICTION.fields_by_name['trajectories'].message_type = _SCOREDTRAJECTORY
_PREDICTIONSET.fields_by_name['predictions'].message_type = _SINGLEOBJECTPREDICTION
_OBJECTTRAJECTORY.fields_by_name['trajectory'].message_type = _TRAJECTORY
_SCOREDJOINTTRAJECTORY.fields_by_name['trajectories'].message_type = _OBJECTTRAJECTORY
_JOINTPREDICTION.fields_by_name['joint_trajectories'].message_type = _SCOREDJOINTTRAJECTORY
_CHALLENGESCENARIOPREDICTIONS.fields_by_name['single_predictions'].message_type = _PREDICTIONSET
_CHALLENGESCENARIOPREDICTIONS.fields_by_name['joint_prediction'].message_type = _JOINTPREDICTION
_CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set'].fields.append(
  _CHALLENGESCENARIOPREDICTIONS.fields_by_name['single_predictions'])
_CHALLENGESCENARIOPREDICTIONS.fields_by_name['single_predictions'].containing_oneof = _CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set']
_CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set'].fields.append(
  _CHALLENGESCENARIOPREDICTIONS.fields_by_name['joint_prediction'])
_CHALLENGESCENARIOPREDICTIONS.fields_by_name['joint_prediction'].containing_oneof = _CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set']
_MOTIONCHALLENGESUBMISSION.fields_by_name['submission_type'].enum_type = _MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE
_MOTIONCHALLENGESUBMISSION.fields_by_name['scenario_predictions'].message_type = _CHALLENGESCENARIOPREDICTIONS
_MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE.containing_type = _MOTIONCHALLENGESUBMISSION
DESCRIPTOR.message_types_by_name['Trajectory'] = _TRAJECTORY
DESCRIPTOR.message_types_by_name['ScoredTrajectory'] = _SCOREDTRAJECTORY
DESCRIPTOR.message_types_by_name['SingleObjectPrediction'] = _SINGLEOBJECTPREDICTION
DESCRIPTOR.message_types_by_name['PredictionSet'] = _PREDICTIONSET
DESCRIPTOR.message_types_by_name['ObjectTrajectory'] = _OBJECTTRAJECTORY
DESCRIPTOR.message_types_by_name['ScoredJointTrajectory'] = _SCOREDJOINTTRAJECTORY
DESCRIPTOR.message_types_by_name['JointPrediction'] = _JOINTPREDICTION
DESCRIPTOR.message_types_by_name['ChallengeScenarioPredictions'] = _CHALLENGESCENARIOPREDICTIONS
DESCRIPTOR.message_types_by_name['MotionChallengeSubmission'] = _MOTIONCHALLENGESUBMISSION
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

Trajectory = _reflection.GeneratedProtocolMessageType('Trajectory', (_message.Message,), {
  'DESCRIPTOR' : _TRAJECTORY,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.Trajectory)
  })
_sym_db.RegisterMessage(Trajectory)

ScoredTrajectory = _reflection.GeneratedProtocolMessageType('ScoredTrajectory', (_message.Message,), {
  'DESCRIPTOR' : _SCOREDTRAJECTORY,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ScoredTrajectory)
  })
_sym_db.RegisterMessage(ScoredTrajectory)

SingleObjectPrediction = _reflection.GeneratedProtocolMessageType('SingleObjectPrediction', (_message.Message,), {
  'DESCRIPTOR' : _SINGLEOBJECTPREDICTION,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.SingleObjectPrediction)
  })
_sym_db.RegisterMessage(SingleObjectPrediction)

PredictionSet = _reflection.GeneratedProtocolMessageType('PredictionSet', (_message.Message,), {
  'DESCRIPTOR' : _PREDICTIONSET,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.PredictionSet)
  })
_sym_db.RegisterMessage(PredictionSet)

ObjectTrajectory = _reflection.GeneratedProtocolMessageType('ObjectTrajectory', (_message.Message,), {
  'DESCRIPTOR' : _OBJECTTRAJECTORY,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ObjectTrajectory)
  })
_sym_db.RegisterMessage(ObjectTrajectory)

ScoredJointTrajectory = _reflection.GeneratedProtocolMessageType('ScoredJointTrajectory', (_message.Message,), {
  'DESCRIPTOR' : _SCOREDJOINTTRAJECTORY,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ScoredJointTrajectory)
  })
_sym_db.RegisterMessage(ScoredJointTrajectory)

JointPrediction = _reflection.GeneratedProtocolMessageType('JointPrediction', (_message.Message,), {
  'DESCRIPTOR' : _JOINTPREDICTION,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.JointPrediction)
  })
_sym_db.RegisterMessage(JointPrediction)

ChallengeScenarioPredictions = _reflection.GeneratedProtocolMessageType('ChallengeScenarioPredictions', (_message.Message,), {
  'DESCRIPTOR' : _CHALLENGESCENARIOPREDICTIONS,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ChallengeScenarioPredictions)
  })
_sym_db.RegisterMessage(ChallengeScenarioPredictions)

MotionChallengeSubmission = _reflection.GeneratedProtocolMessageType('MotionChallengeSubmission', (_message.Message,), {
  'DESCRIPTOR' : _MOTIONCHALLENGESUBMISSION,
  '__module__' : 'motion_submission_pb2'
  # @@protoc_insertion_point(class_scope:waymo.open_dataset.MotionChallengeSubmission)
  })
_sym_db.RegisterMessage(MotionChallengeSubmission)


_TRAJECTORY.fields_by_name['center_x']._options = None
_TRAJECTORY.fields_by_name['center_y']._options = None
# @@protoc_insertion_point(module_scope)


================================================
FILE: submit.py
================================================
import argparse

# chage this if you have problem
import sys
sys.path.insert(1, "~/.local/lib/python3.6/site-packages")


import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from submission_proto import motion_submission_pb2
from train import WaymoLoader, Model


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-data", type=str, required=True, help="Path to rasterized data"
    )
    parser.add_argument(
        "--model-path", type=str, required=True, help="Path to CNN model"
    )
    parser.add_argument(
        "--time-limit", type=int, required=False, default=80, help="Number time steps"
    )
    parser.add_argument(
        "--save", type=str, required=True, help="Path to save predictions"
    )
    parser.add_argument(
        "--model-name", type=str, required=False, help="Model name"
    )

    parser.add_argument("--account-name", required=False, default="")
    parser.add_argument("--authors", required=False, default="")
    parser.add_argument("--method-name", required=False, default="SimpleCNNOnRaster")

    parser.add_argument("--batch-size", type=int, required=False, default=128)

    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    print(args)

    if args.model_path.endswith("pth"):
        model = Model(args.model_name)
        model.load_state_dict(torch.load(args.model_path)["model_state_dict"])
    else:
        model = torch.jit.load(args.model_path)

    model.cuda()
    model.eval()

    dataset = WaymoLoader(args.test_data, is_test=True)
    loader = DataLoader(
        dataset, batch_size=args.batch_size, num_workers=min(args.batch_size, 16)
    )

    RES = {}
    with torch.no_grad():
        for x, center, yaw, agent_id, scenario_id, _, _ in tqdm(loader):
            x = x.cuda()
            confidences_logits, logits = model(x)
            confidences = torch.softmax(confidences_logits, dim=1)

            logits = logits.cpu().numpy()
            confidences = confidences.cpu().numpy()
            agent_id = agent_id.cpu().numpy()
            center = center.cpu().numpy()
            yaw = yaw.cpu().numpy()
            for p, conf, aid, sid, c, y in zip(
                logits, confidences, agent_id, scenario_id, center, yaw
            ):
                if sid not in RES:
                    RES[sid] = []

                RES[sid].append(
                    {"aid": aid, "conf": conf, "pred": p, "yaw": -y, "center": c}
                )

    motion_challenge_submission = motion_submission_pb2.MotionChallengeSubmission()
    motion_challenge_submission.account_name = args.account_name
    motion_challenge_submission.authors.extend(args.authors.split(","))
    motion_challenge_submission.submission_type = (
        motion_submission_pb2.MotionChallengeSubmission.SubmissionType.MOTION_PREDICTION
    )
    motion_challenge_submission.unique_method_name = args.method_name

    selector = np.arange(4, args.time_limit + 1, 5)
    for scenario_id, data in tqdm(RES.items()):
        scenario_predictions = motion_challenge_submission.scenario_predictions.add()
        scenario_predictions.scenario_id = scenario_id
        prediction_set = scenario_predictions.single_predictions

        for d in data:
            predictions = prediction_set.predictions.add()
            predictions.object_id = int(d["aid"])

            y = d["yaw"]
            rot_matrix = np.array([
                [np.cos(y), -np.sin(y)],
                [np.sin(y), np.cos(y)],
            ])

            for i in np.argsort(-d["conf"]):
                scored_trajectory = predictions.trajectories.add()
                scored_trajectory.confidence = d["conf"][i]

                trajectory = scored_trajectory.trajectory

                p = d["pred"][i][selector] @ rot_matrix + d["center"]

                trajectory.center_x.extend(p[:, 0])
                trajectory.center_y.extend(p[:, 1])

    with open(args.save, "wb") as f:
        f.write(motion_challenge_submission.SerializeToString())


if __name__ == "__main__":
    main()


================================================
FILE: train.py
================================================
import argparse
import os

import numpy as np
import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

IMG_RES = 224
IN_CHANNELS = 25
TL = 80
N_TRAJS = 6


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-data", type=str, required=True, help="Path to rasterized data"
    )
    parser.add_argument(
        "--dev-data", type=str, required=True, help="Path to rasterized data"
    )
    parser.add_argument(
        "--img-res",
        type=int,
        required=False,
        default=IMG_RES,
        help="Input images resolution",
    )
    parser.add_argument(
        "--in-channels",
        type=int,
        required=False,
        default=IN_CHANNELS,
        help="Input raster channels",
    )
    parser.add_argument(
        "--time-limit",
        type=int,
        required=False,
        default=TL,
        help="Number time step to predict",
    )
    parser.add_argument(
        "--n-traj",
        type=int,
        required=False,
        default=N_TRAJS,
        help="Number of trajectories to predict",
    )
    parser.add_argument(
        "--save", type=str, required=True, help="Path to save model and logs"
    )

    parser.add_argument(
        "--model", type=str, required=False, default="xception71", help="CNN model name"
    )
    parser.add_argument("--lr", type=float, required=False, default=1e-3)
    parser.add_argument("--batch-size", type=int, required=False, default=48)
    parser.add_argument("--n-epochs", type=int, required=False, default=60)

    parser.add_argument("--valid-limit", type=int, required=False, default=24 * 100)
    parser.add_argument(
        "--n-monitor-train",
        type=int,
        required=False,
        default=10,
        help="Validate model each `n-validate` steps",
    )
    parser.add_argument(
        "--n-monitor-validate",
        type=int,
        required=False,
        default=1000,
        help="Validate model each `n-validate` steps",
    )

    args = parser.parse_args()

    return args


class Model(nn.Module):
    def __init__(
        self, model_name, in_channels=IN_CHANNELS, time_limit=TL, n_traj=N_TRAJS
    ):
        super().__init__()

        self.n_traj = n_traj
        self.time_limit = time_limit
        self.model = timm.create_model(
            model_name,
            pretrained=True,
            in_chans=in_channels,
            num_classes=self.n_traj * 2 * self.time_limit + self.n_traj,
        )


    def forward(self, x):
        outputs = self.model(x)

        confidences_logits, logits = (
            outputs[:, : self.n_traj],
            outputs[:, self.n_traj :],
        )
        logits = logits.view(-1, self.n_traj, self.time_limit, 2)

        return confidences_logits, logits


def pytorch_neg_multi_log_likelihood_batch(gt, logits, confidences, avails):
    """
    Compute a negative log-likelihood for the multi-modal scenario.
    Args:
        gt (Tensor): array of shape (bs)x(time)x(2D coords)
        logits (Tensor): array of shape (bs)x(modes)x(time)x(2D coords)
        confidences (Tensor): array of shape (bs)x(modes) with a confidence for each mode in each sample
        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep
    Returns:
        Tensor: negative log-likelihood for this example, a single float number
    """

    # convert to (batch_size, num_modes, future_len, num_coords)
    gt = torch.unsqueeze(gt, 1)  # add modes
    avails = avails[:, None, :, None]  # add modes and cords

    # error (batch_size, num_modes, future_len)
    error = torch.sum(
        ((gt - logits) * avails) ** 2, dim=-1
    )  # reduce coords and use availability

    with np.errstate(
        divide="ignore"
    ):  # when confidence is 0 log goes to -inf, but we're fine with it
        # error (batch_size, num_modes)
        error = nn.functional.log_softmax(confidences, dim=1) - 0.5 * torch.sum(
            error, dim=-1
        )  # reduce time

    # error (batch_size, num_modes)
    error = -torch.logsumexp(error, dim=-1, keepdim=True)

    return torch.mean(error)


class WaymoLoader(Dataset):
    def __init__(self, directory, limit=0, return_vector=False, is_test=False):
        files = os.listdir(directory)
        self.files = [os.path.join(directory, f) for f in files if f.endswith(".npz")]

        if limit > 0:
            self.files = self.files[:limit]
        else:
            self.files = sorted(self.files)

        self.return_vector = return_vector
        self.is_test = is_test

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        filename = self.files[idx]
        data = np.load(filename, allow_pickle=True)

        raster = data["raster"].astype("float32")
        raster = raster.transpose(2, 1, 0) / 255

        if self.is_test:
            center = data["shift"]
            yaw = data["yaw"]
            agent_id = data["object_id"]
            scenario_id = data["scenario_id"]

            return (
                raster,
                center,
                yaw,
                agent_id,
                str(scenario_id),
                data["_gt_marginal"],
                data["gt_marginal"],
            )

        trajectory = data["gt_marginal"]

        is_available = data["future_val_marginal"]

        if self.return_vector:
            return raster, trajectory, is_available, data["vector_data"]

        return raster, trajectory, is_available


def main():
    args = parse_args()

    summary_writer = SummaryWriter(os.path.join(args.save, "logs"))

    train_path = args.train_data
    dev_path = args.dev_data
    path_to_save = args.save
    if not os.path.exists(path_to_save):
        os.mkdir(path_to_save)

    dataset = WaymoLoader(train_path)

    batch_size = args.batch_size
    num_workers = min(16, batch_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=False,
        persistent_workers=True,
    )

    val_dataset = WaymoLoader(dev_path, limit=args.valid_limit)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        persistent_workers=True,
    )

    model_name = args.model
    time_limit = args.time_limit
    n_traj = args.n_traj
    model = Model(
        model_name, in_channels=args.in_channels, time_limit=time_limit, n_traj=n_traj
    )
    model.cuda()

    lr = args.lr
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=2 * len(dataloader),
        T_mult=1,
        eta_min=max(1e-2 * lr, 1e-6),
        last_epoch=-1,
    )

    start_iter = 0
    best_loss = float("+inf")
    glosses = []

    tr_it = iter(dataloader)
    n_epochs = args.n_epochs
    progress_bar = tqdm(range(start_iter, len(dataloader) * n_epochs))

    saver = lambda name: torch.save(
        {
            "score": best_loss,
            "iteration": iteration,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": loss.item(),
        },
        os.path.join(path_to_save, name),
    )

    for iteration in progress_bar:
        model.train()
        try:
            x, y, is_available = next(tr_it)
        except StopIteration:
            tr_it = iter(dataloader)
            x, y, is_available = next(tr_it)

        x, y, is_available = map(lambda x: x.cuda(), (x, y, is_available))

        optimizer.zero_grad()

        confidences_logits, logits = model(x)

        loss = pytorch_neg_multi_log_likelihood_batch(
            y, logits, confidences_logits, is_available
        )
        loss.backward()
        optimizer.step()
        scheduler.step()

        glosses.append(loss.item())
        if (iteration + 1) % args.n_monitor_train == 0:
            progress_bar.set_description(
                f"loss: {loss.item():.3}"
                f" avg: {np.mean(glosses[-100:]):.2}"
                f" {scheduler.get_last_lr()[-1]:.3}"
            )
            summary_writer.add_scalar("train/loss", loss.item(), iteration)
            summary_writer.add_scalar("lr", scheduler.get_last_lr()[-1], iteration)

        if (iteration + 1) % args.n_monitor_validate == 0:
            optimizer.zero_grad()
            model.eval()
            with torch.no_grad():
                val_losses = []
                for x, y, is_available in val_dataloader:
                    x, y, is_available = map(lambda x: x.cuda(), (x, y, is_available))

                    confidences_logits, logits = model(x)
                    loss = pytorch_neg_multi_log_likelihood_batch(
                        y, logits, confidences_logits, is_available
                    )
                    val_losses.append(loss.item())

                summary_writer.add_scalar("dev/loss", np.mean(val_losses), iteration)

            saver("model_last.pth")

            mean_val_loss = np.mean(val_losses)
            if mean_val_loss < best_loss:
                best_loss = mean_val_loss
                saver("model_best.pth")

                model.eval()
                with torch.no_grad():
                    traced_model = torch.jit.trace(
                        model,
                        torch.rand(
                            1, args.in_channels, args.img_res, args.img_res
                        ).cuda(),
                    )

                traced_model.save(os.path.join(path_to_save, "model_best.pt"))
                del traced_model


if __name__ == "__main__":
    main()


================================================
FILE: visualize.py
================================================
import argparse
import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
from torch.utils.data import DataLoader

from train import WaymoLoader, pytorch_neg_multi_log_likelihood_batch


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--data", type=str, required=True)
    parser.add_argument("--save", type=str, required=True)
    parser.add_argument("--n-samples", type=int, required=False, default=50)
    parser.add_argument("--use-top1", action="store_true")

    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    if not os.path.exists(args.save):
        os.mkdir(args.save)

    model = torch.jit.load(args.model).cuda().eval()
    loader = DataLoader(
        WaymoLoader(args.data, return_vector=True),
        batch_size=1,
        num_workers=1,
        shuffle=False,
    )

    iii = 0
    with torch.no_grad():
        for x, y, is_available, vector_data in loader:
            x, y, is_available = map(lambda x: x.cuda(), (x, y, is_available))

            confidences_logits, logits = model(x)

            argmax = confidences_logits.argmax()
            if args.use_top1:
                confidences_logits = confidences_logits[:, argmax].unsqueeze(1)
                logits = logits[:, argmax].unsqueeze(1)

            loss = pytorch_neg_multi_log_likelihood_batch(
                y, logits, confidences_logits, is_available
            )
            confidences = torch.softmax(confidences_logits, dim=1)
            V = vector_data[0]

            X, idx = V[:, :44], V[:, 44].flatten()

            figure(figsize=(15, 15), dpi=80)
            for i in np.unique(idx):
                _X = X[idx == i]
                if _X[:, 5:12].sum() > 0:
                    plt.plot(_X[:, 0], _X[:, 1], linewidth=4, color="red")
                else:
                    plt.plot(_X[:, 0], _X[:, 1], color="black")
                plt.xlim([-224 // 4, 224 // 4])
                plt.ylim([-224 // 4, 224 // 4])

            logits = logits.squeeze(0).cpu().numpy()
            y = y.squeeze(0).cpu().numpy()
            is_available = is_available.squeeze(0).long().cpu().numpy()
            confidences = confidences.squeeze(0).cpu().numpy()
            plt.plot(
                y[is_available > 0][::10, 0],
                y[is_available > 0][::10, 1],
                "-o",
                label="gt",
            )

            plt.plot(
                logits[confidences.argmax()][is_available > 0][::10, 0],
                logits[confidences.argmax()][is_available > 0][::10, 1],
                "-o",
                label="pred top 1",
            )
            if not args.use_top1:
                for traj_id in range(len(logits)):
                    if traj_id == argmax:
                        continue

                    alpha = confidences[traj_id].item()
                    plt.plot(
                        logits[traj_id][is_available > 0][::10, 0],
                        logits[traj_id][is_available > 0][::10, 1],
                        "-o",
                        label=f"pred {traj_id} {alpha:.3f}",
                        alpha=alpha,
                    )


            plt.title(loss.item())
            plt.legend()
            plt.savefig(
                os.path.join(args.save, f"{iii:0>2}_{loss.item():.3f}.png")
            )
            plt.close()

            iii += 1
            if iii == args.n_samples:
                break


if __name__ == "__main__":
    main()
Download .txt
gitextract_9tsvvauy/

├── README.md
├── prerender.py
├── requirements.txt
├── submission_proto/
│   └── motion_submission_pb2.py
├── submit.py
├── train.py
└── visualize.py
Download .txt
SYMBOL INDEX (21 symbols across 4 files)

FILE: prerender.py
  function parse_arguments (line 181) | def parse_arguments():
  function rasterize (line 214) | def rasterize(
  function ohe (line 498) | def ohe(N, n, zero):
  function make_2d (line 509) | def make_2d(arraylist):
  function vectorize (line 518) | def vectorize(
  function merge (line 728) | def merge(
  function main (line 799) | def main():

FILE: submit.py
  function parse_args (line 17) | def parse_args():
  function main (line 46) | def main():

FILE: train.py
  function parse_args (line 18) | def parse_args():
  class Model (line 86) | class Model(nn.Module):
    method __init__ (line 87) | def __init__(
    method forward (line 102) | def forward(self, x):
  function pytorch_neg_multi_log_likelihood_batch (line 114) | def pytorch_neg_multi_log_likelihood_batch(gt, logits, confidences, avai...
  class WaymoLoader (line 149) | class WaymoLoader(Dataset):
    method __init__ (line 150) | def __init__(self, directory, limit=0, return_vector=False, is_test=Fa...
    method __len__ (line 162) | def __len__(self):
    method __getitem__ (line 165) | def __getitem__(self, idx):
  function main (line 198) | def main():

FILE: visualize.py
  function parse_args (line 13) | def parse_args():
  function main (line 26) | def main():
Condensed preview — 7 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (73K chars).
[
  {
    "path": "README.md",
    "chars": 2769,
    "preview": "# Waymo motion prediction challenge 2021: 3rd place solution \n\n![header](docs/header.png)\n\n- 📜[**Paper**](https://arxiv."
  },
  {
    "path": "prerender.py",
    "chars": 25777,
    "preview": "import argparse\nimport multiprocessing\nimport os\n\nimport cv2\nimport numpy as np\nimport tensorflow as tf\nfrom tqdm import"
  },
  {
    "path": "requirements.txt",
    "chars": 47,
    "preview": "numpy\nopencv-python\ntensorflow\ntimm\ntorch\ntqdm\n"
  },
  {
    "path": "submission_proto/motion_submission_pb2.py",
    "chars": 23710,
    "preview": "# -*- coding: utf-8 -*-\n# Generated by the protocol buffer compiler.  DO NOT EDIT!\n# source: motion_submission.proto\n\"\"\""
  },
  {
    "path": "submit.py",
    "chars": 4129,
    "preview": "import argparse\n\n# chage this if you have problem\nimport sys\nsys.path.insert(1, \"~/.local/lib/python3.6/site-packages\")\n"
  },
  {
    "path": "train.py",
    "chars": 9943,
    "preview": "import argparse\nimport os\n\nimport numpy as np\nimport timm\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data impor"
  },
  {
    "path": "visualize.py",
    "chars": 3621,
    "preview": "import argparse\nimport os\n\nimport numpy as np\nimport torch\nfrom matplotlib import pyplot as plt\nfrom matplotlib.pyplot i"
  }
]

About this extraction

This page contains the full source code of the kbrodt/waymo-motion-prediction-2021 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 7 files (68.4 KB), approximately 17.5k tokens, and a symbol index with 21 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!