[
  {
    "path": "README.md",
    "content": "# Waymo motion prediction challenge 2021: 3rd place solution\n\n![header](docs/header.png)\n\n- 📜[**Paper**](https://arxiv.org/abs/2206.02163)  \n- 🗨️[Presentation](./docs/waymo_motion_prediction_2021_3rd_place_solution_presentation.pdf)  \n- 🎉[Announcement](https://youtu.be/eOL_rCK59ZI?t=6485)  \n- 🛆[Motion Prediction Challenge Website](https://waymo.com/open/challenges/2021/motion-prediction/)  \n- 🛆[CVPR2021 workshop](http://cvpr2021.wad.vision/)  \n\n- **UPDATE❗** Related repo with the [3rd place solution code](https://github.com/stepankonev/waymo-motion-prediction-challenge-2022-multipath-plus-plus) for the Waymo Motion Prediction Challenge 2022\n- **UPDATE❗** Related repo with [refactored code for MotionCNN](https://github.com/stepankonev/MotionCNN-Waymo-Open-Motion-Dataset)\n\n## Team behind this solution:\n1. Artsiom Sanakoyeu [[Homepage](https://gdude.de)] [[Twitter](https://twitter.com/artsiom_s)] [[Telegram Channel](https://t.me/gradientdude)] [[LinkedIn](https://www.linkedin.com/in/sanakoev)]\n2. Stepan Konev [[LinkedIn](https://www.linkedin.com/in/stepan-konev/)]\n3. 
Kirill Brodt [[GitHub](https://github.com/kbrodt)]\n\n## Dataset\n\nDownload the\n[datasets](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_0_0)\nfrom `uncompressed/tf_example/{training,validation,testing}`\n\n## Prerender\n\nChange the paths to the input dataset and the output folders:\n\n```bash\npython prerender.py \\\n    --data /home/data/waymo/training \\\n    --out ./train\n\npython prerender.py \\\n    --data /home/data/waymo/validation \\\n    --out ./dev \\\n    --use-vectorize \\\n    --n-shards 1\n\npython prerender.py \\\n    --data /home/data/waymo/testing \\\n    --out ./test \\\n    --use-vectorize \\\n    --n-shards 1\n```\n\n## Training\n\n```bash\nMODEL_NAME=xception71\npython train.py \\\n    --train-data ./train \\\n    --dev-data ./dev \\\n    --save ./${MODEL_NAME} \\\n    --model ${MODEL_NAME} \\\n    --img-res 224 \\\n    --in-channels 25 \\\n    --time-limit 80 \\\n    --n-traj 6 \\\n    --lr 0.001 \\\n    --batch-size 48 \\\n    --n-epochs 120\n```\n\n## Submit\n\n```bash\npython submit.py \\\n    --test-data ./test/ \\\n    --model-path ${MODEL_PATH_TO_JIT} \\\n    --save ${SAVE}\n```\n\n## Visualize predictions\n\n```bash\npython visualize.py \\\n    --model ${MODEL_PATH_TO_JIT} \\\n    --data ${DATA_PATH} \\\n    --save ./viz\n```\n\n## Citation\nIf you find our work useful, please cite it as:\n```\n@misc{konev2022motioncnn,\n      title={MotionCNN: A Strong Baseline for Motion Prediction in Autonomous Driving},\n      author={Stepan Konev and Kirill Brodt and Artsiom Sanakoyeu},\n      year={2022},\n      eprint={2206.02163},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n```\n\n## Related repos\n\n* [Kaggle Lyft motion prediction 3rd place solution](https://gdude.de/blog/2021-02-05/Kaggle-Lyft-solution)\n"
  },
  {
    "path": "prerender.py",
    "content": "import argparse\nimport multiprocessing\nimport os\n\nimport cv2\nimport numpy as np\nimport tensorflow as tf\nfrom tqdm import tqdm\n\nroadgraph_features = {\n    \"roadgraph_samples/dir\": tf.io.FixedLenFeature(\n        [20000, 3], tf.float32, default_value=None\n    ),\n    \"roadgraph_samples/id\": tf.io.FixedLenFeature(\n        [20000, 1], tf.int64, default_value=None\n    ),\n    \"roadgraph_samples/type\": tf.io.FixedLenFeature(\n        [20000, 1], tf.int64, default_value=None\n    ),\n    \"roadgraph_samples/valid\": tf.io.FixedLenFeature(\n        [20000, 1], tf.int64, default_value=None\n    ),\n    \"roadgraph_samples/xyz\": tf.io.FixedLenFeature(\n        [20000, 3], tf.float32, default_value=None\n    ),\n}\n\n# Features of other agents.\nstate_features = {\n    \"state/id\": tf.io.FixedLenFeature([128], tf.float32, default_value=None),\n    \"state/type\": tf.io.FixedLenFeature([128], tf.float32, default_value=None),\n    \"state/is_sdc\": tf.io.FixedLenFeature([128], tf.int64, default_value=None),\n    \"state/tracks_to_predict\": tf.io.FixedLenFeature(\n        [128], tf.int64, default_value=None\n    ),\n    \"state/current/bbox_yaw\": tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/height\": tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/length\": tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/timestamp_micros\": tf.io.FixedLenFeature(\n        [128, 1], tf.int64, default_value=None\n    ),\n    \"state/current/valid\": tf.io.FixedLenFeature(\n        [128, 1], tf.int64, default_value=None\n    ),\n    \"state/current/vel_yaw\": tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/velocity_x\": tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/velocity_y\": 
tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/speed\": tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/width\": tf.io.FixedLenFeature(\n        [128, 1], tf.float32, default_value=None\n    ),\n    \"state/current/x\": tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),\n    \"state/current/y\": tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),\n    \"state/current/z\": tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),\n    \"state/future/bbox_yaw\": tf.io.FixedLenFeature(\n        [128, 80], tf.float32, default_value=None\n    ),\n    \"state/future/height\": tf.io.FixedLenFeature(\n        [128, 80], tf.float32, default_value=None\n    ),\n    \"state/future/length\": tf.io.FixedLenFeature(\n        [128, 80], tf.float32, default_value=None\n    ),\n    \"state/future/timestamp_micros\": tf.io.FixedLenFeature(\n        [128, 80], tf.int64, default_value=None\n    ),\n    \"state/future/valid\": tf.io.FixedLenFeature(\n        [128, 80], tf.int64, default_value=None\n    ),\n    \"state/future/vel_yaw\": tf.io.FixedLenFeature(\n        [128, 80], tf.float32, default_value=None\n    ),\n    \"state/future/velocity_x\": tf.io.FixedLenFeature(\n        [128, 80], tf.float32, default_value=None\n    ),\n    \"state/future/velocity_y\": tf.io.FixedLenFeature(\n        [128, 80], tf.float32, default_value=None\n    ),\n    \"state/future/width\": tf.io.FixedLenFeature(\n        [128, 80], tf.float32, default_value=None\n    ),\n    \"state/future/x\": tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),\n    \"state/future/y\": tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),\n    \"state/future/z\": tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),\n    \"state/past/bbox_yaw\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    
\"state/past/height\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    \"state/past/length\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    \"state/past/timestamp_micros\": tf.io.FixedLenFeature(\n        [128, 10], tf.int64, default_value=None\n    ),\n    \"state/past/valid\": tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),\n    \"state/past/vel_yaw\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    \"state/past/velocity_x\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    \"state/past/velocity_y\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    \"state/past/speed\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    \"state/past/width\": tf.io.FixedLenFeature(\n        [128, 10], tf.float32, default_value=None\n    ),\n    \"state/past/x\": tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),\n    \"state/past/y\": tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),\n    \"state/past/z\": tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),\n    \"scenario/id\": tf.io.FixedLenFeature([1], tf.string, default_value=None),\n}\n\ntraffic_light_features = {\n    \"traffic_light_state/current/state\": tf.io.FixedLenFeature(\n        [1, 16], tf.int64, default_value=None\n    ),\n    \"traffic_light_state/current/valid\": tf.io.FixedLenFeature(\n        [1, 16], tf.int64, default_value=None\n    ),\n    \"traffic_light_state/current/id\": tf.io.FixedLenFeature(\n        [1, 16], tf.int64, default_value=None\n    ),\n    \"traffic_light_state/current/x\": tf.io.FixedLenFeature(\n        [1, 16], tf.float32, default_value=None\n    ),\n    \"traffic_light_state/current/y\": tf.io.FixedLenFeature(\n        [1, 16], tf.float32, default_value=None\n    ),\n    
\"traffic_light_state/current/z\": tf.io.FixedLenFeature(\n        [1, 16], tf.float32, default_value=None\n    ),\n    \"traffic_light_state/past/state\": tf.io.FixedLenFeature(\n        [10, 16], tf.int64, default_value=None\n    ),\n    \"traffic_light_state/past/valid\": tf.io.FixedLenFeature(\n        [10, 16], tf.int64, default_value=None\n    ),\n    # \"traffic_light_state/past/id\":\n    # tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),\n    \"traffic_light_state/past/x\": tf.io.FixedLenFeature(\n        [10, 16], tf.float32, default_value=None\n    ),\n    \"traffic_light_state/past/y\": tf.io.FixedLenFeature(\n        [10, 16], tf.float32, default_value=None\n    ),\n    \"traffic_light_state/past/z\": tf.io.FixedLenFeature(\n        [10, 16], tf.float32, default_value=None\n    ),\n}\n\nfeatures_description = {}\nfeatures_description.update(roadgraph_features)\nfeatures_description.update(state_features)\nfeatures_description.update(traffic_light_features)\nMAX_PIXEL_VALUE = 255\nN_ROADS = 21\nroad_colors = [int(x) for x in np.linspace(1, MAX_PIXEL_VALUE, N_ROADS).astype(\"uint8\")]\nidx2type = [\"unset\", \"vehicle\", \"pedestrian\", \"cyclist\", \"other\"]\n\n\ndef parse_arguments():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data\", type=str, required=True, help=\"Path to raw data\")\n    parser.add_argument(\"--out\", type=str, required=True, help=\"Path to save data\")\n    parser.add_argument(\n        \"--no-valid\", action=\"store_true\", help=\"Use data with flag `valid = 0`\"\n    )\n    parser.add_argument(\n        \"--use-vectorize\", action=\"store_true\", help=\"Generate vector data\"\n    )\n    parser.add_argument(\n        \"--n-jobs\", type=int, default=20, required=False, help=\"Number of threads\"\n    )\n    parser.add_argument(\n        \"--n-shards\",\n        type=int,\n        default=8,\n        required=False,\n        help=\"Use `1/n_shards` of full dataset\",\n    )\n    
parser.add_argument(\n        \"--each\",\n        type=int,\n        default=0,\n        required=False,\n        help=\"Take `each` sample in shard\",\n    )\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef rasterize(\n    tracks_to_predict,\n    past_x,\n    past_y,\n    current_x,\n    current_y,\n    current_yaw,\n    past_yaw,\n    past_valid,\n    current_valid,\n    agent_type,\n    roadlines_coords,\n    roadlines_types,\n    roadlines_valid,\n    roadlines_ids,\n    widths,\n    lengths,\n    agents_ids,\n    tl_states,\n    tl_ids,\n    tl_valids,\n    future_x,\n    future_y,\n    future_valid,\n    scenario_id,\n    validate,\n    crop_size=512,\n    raster_size=224,\n    shift=2 ** 9,\n    magic_const=3,\n    n_channels=11,\n):\n    GRES = []\n    displacement = np.array([[raster_size // 4, raster_size // 2]]) * shift\n    tl_dict = {\"green\": set(), \"yellow\": set(), \"red\": set()}\n\n    # Unknown = 0, Arrow_Stop = 1, Arrow_Caution = 2, Arrow_Go = 3, Stop = 4,\n    # Caution = 5, Go = 6, Flashing_Stop = 7, Flashing_Caution = 8\n    for tl_state, tl_id, tl_valid in zip(\n        tl_states.flatten(), tl_ids.flatten(), tl_valids.flatten()\n    ):\n        if tl_valid == 0:\n            continue\n        if tl_state in [1, 4, 7]:\n            tl_dict[\"red\"].add(tl_id)\n        if tl_state in [2, 5, 8]:\n            tl_dict[\"yellow\"].add(tl_id)\n        if tl_state in [3, 6]:\n            tl_dict[\"green\"].add(tl_id)\n\n    XY = np.concatenate(\n        (\n            np.expand_dims(np.concatenate((past_x, current_x), axis=1), axis=-1),\n            np.expand_dims(np.concatenate((past_y, current_y), axis=1), axis=-1),\n        ),\n        axis=-1,\n    )\n\n    GT_XY = np.concatenate(\n        (np.expand_dims(future_x, axis=-1), np.expand_dims(future_y, axis=-1)), axis=-1\n    )\n\n    YAWS = np.concatenate((past_yaw, current_yaw), axis=1)\n\n    agents_valid = np.concatenate((past_valid, current_valid), axis=1)\n\n    
roadlines_valid = roadlines_valid.reshape(-1)\n    roadlines_coords = (\n        roadlines_coords[:, :2][roadlines_valid > 0]\n        * shift\n        * magic_const\n        * raster_size\n        / crop_size\n    )\n    roadlines_types = roadlines_types[roadlines_valid > 0]\n    roadlines_ids = roadlines_ids.reshape(-1)[roadlines_valid > 0]\n\n    for _, (\n        xy,\n        current_val,\n        val,\n        _,\n        yaw,\n        agent_id,\n        gt_xy,\n        future_val,\n        predict,\n    ) in enumerate(\n        zip(\n            XY,\n            current_valid,\n            agents_valid,\n            agent_type,\n            current_yaw.flatten(),\n            agents_ids,\n            GT_XY,\n            future_valid,\n            tracks_to_predict.flatten(),\n        )\n    ):\n        if (not validate and future_val.sum() == 0) or (validate and predict == 0):\n            continue\n        if current_val == 0:\n            continue\n\n        RES_ROADMAP = (\n            np.ones((raster_size, raster_size, 3), dtype=np.uint8) * MAX_PIXEL_VALUE\n        )\n        RES_EGO = [\n            np.zeros((raster_size, raster_size, 1), dtype=np.uint8)\n            for _ in range(n_channels)\n        ]\n        RES_OTHER = [\n            np.zeros((raster_size, raster_size, 1), dtype=np.uint8)\n            for _ in range(n_channels)\n        ]\n\n        xy_val = xy[val > 0]\n        if len(xy_val) == 0:\n            continue\n\n        unscaled_center_xy = xy_val[-1].reshape(1, -1)\n        center_xy = unscaled_center_xy * shift * magic_const * raster_size / crop_size\n        rot_matrix = np.array(\n            [\n                [np.cos(yaw), -np.sin(yaw)],\n                [np.sin(yaw), np.cos(yaw)],\n            ]\n        )\n\n        centered_roadlines = (roadlines_coords - center_xy) @ rot_matrix + displacement\n        centered_others = (\n            XY.reshape(-1, 2) * shift * magic_const * raster_size / crop_size\n            - center_xy\n   
     ) @ rot_matrix + displacement\n        centered_others = centered_others.reshape(128, n_channels, 2)\n        centered_gt = (gt_xy - unscaled_center_xy) @ rot_matrix\n\n        unique_road_ids = np.unique(roadlines_ids)\n        for road_id in unique_road_ids:\n            if road_id >= 0:\n                roadline = centered_roadlines[roadlines_ids == road_id]\n                road_type = roadlines_types[roadlines_ids == road_id].flatten()[0]\n\n                road_color = road_colors[road_type]\n                for c, rgb in zip(\n                    [\"green\", \"yellow\", \"red\"],\n                    [\n                        (0, MAX_PIXEL_VALUE, 0),\n                        (MAX_PIXEL_VALUE, 211, 0),\n                        (MAX_PIXEL_VALUE, 0, 0),\n                    ],\n                ):\n                    if road_id in tl_dict[c]:\n                        road_color = rgb\n\n                RES_ROADMAP = cv2.polylines(\n                    RES_ROADMAP,\n                    [roadline.astype(int)],\n                    False,\n                    road_color,\n                    shift=9,\n                )\n\n        unique_agent_ids = np.unique(agents_ids)\n\n        is_ego = False\n        self_type = 0\n        _tmp = 0\n        for other_agent_id in unique_agent_ids:\n            other_agent_id = int(other_agent_id)\n            if other_agent_id < 1:\n                continue\n            if other_agent_id == agent_id:\n                is_ego = True\n                self_type = agent_type[agents_ids == other_agent_id]\n            else:\n                is_ego = False\n\n            _tmp += 1\n            agent_lane = centered_others[agents_ids == other_agent_id][0]\n            agent_valid = agents_valid[agents_ids == other_agent_id]\n            agent_yaw = YAWS[agents_ids == other_agent_id]\n\n            agent_l = lengths[agents_ids == other_agent_id]\n            agent_w = widths[agents_ids == other_agent_id]\n\n            for 
timestamp, (coord, valid_coordinate, past_yaw,) in enumerate(\n                zip(\n                    agent_lane,\n                    agent_valid.flatten(),\n                    agent_yaw.flatten(),\n                )\n            ):\n                if valid_coordinate == 0:\n                    continue\n                box_points = (\n                    np.array(\n                        [\n                            -agent_l,\n                            -agent_w,\n                            agent_l,\n                            -agent_w,\n                            agent_l,\n                            agent_w,\n                            -agent_l,\n                            agent_w,\n                        ]\n                    )\n                    .reshape(4, 2)\n                    .astype(np.float32)\n                    * shift\n                    * magic_const\n                    / 2\n                    * raster_size\n                    / crop_size\n                )\n\n                box_points = (\n                    box_points\n                    @ np.array(\n                        (\n                            (np.cos(yaw - past_yaw), -np.sin(yaw - past_yaw)),\n                            (np.sin(yaw - past_yaw), np.cos(yaw - past_yaw)),\n                        )\n                    ).reshape(2, 2)\n                )\n\n                _coord = np.array([coord])\n\n                box_points = box_points + _coord\n                box_points = box_points.reshape(1, -1, 2).astype(np.int32)\n\n                if is_ego:\n                    cv2.fillPoly(\n                        RES_EGO[timestamp],\n                        box_points,\n                        color=MAX_PIXEL_VALUE,\n                        shift=9,\n                    )\n                else:\n                    cv2.fillPoly(\n                        RES_OTHER[timestamp],\n                        box_points,\n                        color=MAX_PIXEL_VALUE,\n   
                     shift=9,\n                    )\n\n        raster = np.concatenate([RES_ROADMAP] + RES_EGO + RES_OTHER, axis=2)\n\n        raster_dict = {\n            \"object_id\": agent_id,\n            \"raster\": raster,\n            \"yaw\": yaw,\n            \"shift\": unscaled_center_xy,\n            \"_gt_marginal\": gt_xy,\n            \"gt_marginal\": centered_gt,\n            \"future_val_marginal\": future_val,\n            \"gt_joint\": GT_XY[tracks_to_predict.flatten() > 0],\n            \"future_val_joint\": future_valid[tracks_to_predict.flatten() > 0],\n            \"scenario_id\": scenario_id,\n            \"self_type\": self_type,\n        }\n\n        GRES.append(raster_dict)\n\n    return GRES\n\n\nF2I = {\n    \"x\": 0,\n    \"y\": 1,\n    \"s\": 2,\n    \"vel_yaw\": 3,\n    \"bbox_yaw\": 4,\n    \"l\": 5,\n    \"w\": 6,\n    \"agent_type_range\": [7, 12],\n    \"lane_range\": [13, 33],\n    \"lt_range\": [34, 43],\n    \"global_idx\": 44,\n}\n\n\ndef ohe(N, n, zero):\n    n = int(n)\n    N = int(N)\n    M = np.eye(N)\n    diff = 0\n    if zero:\n        M = np.concatenate((np.zeros((1, N)), M), axis=0)\n        diff = 1\n    return M[n + diff]\n\n\ndef make_2d(arraylist):\n    n = len(arraylist)\n    k = arraylist[0].shape[0]\n    a2d = np.zeros((n, k))\n    for i in range(n):\n        a2d[i] = arraylist[i]\n    return a2d\n\n\ndef vectorize(\n    past_x,\n    current_x,\n    past_y,\n    current_y,\n    past_valid,\n    current_valid,\n    past_speed,\n    current_speed,\n    past_velocity_yaw,\n    current_velocity_yaw,\n    past_bbox_yaw,\n    current_bbox_yaw,\n    Agent_id,\n    Agent_type,\n    Roadline_id,\n    Roadline_type,\n    Roadline_valid,\n    Roadline_xy,\n    Tl_rl_id,\n    Tl_state,\n    Tl_valid,\n    W,\n    L,\n    tracks_to_predict,\n    future_valid,\n    validate,\n    n_channels=11,\n):\n\n    XY = np.concatenate(\n        (\n            np.expand_dims(np.concatenate((past_x, current_x), axis=1), axis=-1),\n     
       np.expand_dims(np.concatenate((past_y, current_y), axis=1), axis=-1),\n        ),\n        axis=-1,\n    )\n\n    Roadline_valid = Roadline_valid.flatten()\n    RoadXY = Roadline_xy[:, :2][Roadline_valid > 0]\n    Roadline_type = Roadline_type[Roadline_valid > 0].flatten()\n    Roadline_id = Roadline_id[Roadline_valid > 0].flatten()\n\n    tl_state = [[-1] for _ in range(9)]\n\n    for lane_id, state, valid in zip(\n        Tl_rl_id.flatten(), Tl_state.flatten(), Tl_valid.flatten()\n    ):\n        if valid == 0:\n            continue\n        tl_state[int(state)].append(lane_id)\n\n    VALID = np.concatenate((past_valid, current_valid), axis=1)\n\n    Speed = np.concatenate((past_speed, current_speed), axis=1)\n    Vyaw = np.concatenate((past_velocity_yaw, current_velocity_yaw), axis=1)\n    Bbox_yaw = np.concatenate((past_bbox_yaw, current_bbox_yaw), axis=1)\n\n    GRES = []\n\n    ROADLINES_STATE = []\n\n    GLOBAL_IDX = -1\n\n    unique_road_ids = np.unique(Roadline_id)\n    for road_id in unique_road_ids:\n\n        GLOBAL_IDX += 1\n\n        roadline_coords = RoadXY[Roadline_id == road_id]\n        roadline_type = Roadline_type[Roadline_id == road_id][0]\n\n        for i, (x, y) in enumerate(roadline_coords):\n            if i > 0 and i < len(roadline_coords) - 1 and i % 3 > 0:\n                continue\n            tmp = np.zeros(48)\n            tmp[0] = x\n            tmp[1] = y\n\n            tmp[13:33] = ohe(20, roadline_type, True)\n\n            tmp[44] = GLOBAL_IDX\n\n            ROADLINES_STATE.append(tmp)\n\n    ROADLINES_STATE = make_2d(ROADLINES_STATE)\n\n    for (\n        agent_id,\n        xy,\n        current_val,\n        valid,\n        _,\n        bbox_yaw,\n        _,\n        _,\n        _,\n        future_val,\n        predict,\n    ) in zip(\n        Agent_id,\n        XY,\n        current_valid,\n        VALID,\n        Speed,\n        Bbox_yaw,\n        Vyaw,\n        W,\n        L,\n        future_valid,\n        
tracks_to_predict.flatten(),\n    ):\n\n        if (not validate and future_val.sum() == 0) or (validate and predict == 0):\n            continue\n        if current_val == 0:\n            continue\n\n        GLOBAL_IDX = -1\n        RES = []\n\n        xy_val = xy[valid > 0]\n        if len(xy_val) == 0:\n            continue\n\n        centered_xy = xy_val[-1].copy().reshape(-1, 2)\n\n        ANGLE = bbox_yaw[-1]\n\n        rot_matrix = np.array(\n            [\n                [np.cos(ANGLE), -np.sin(ANGLE)],\n                [np.sin(ANGLE), np.cos(ANGLE)],\n            ]\n        ).reshape(2, 2)\n\n        local_roadlines_state = ROADLINES_STATE.copy()\n\n        local_roadlines_state[:, :2] = (\n            local_roadlines_state[:, :2] - centered_xy\n        ) @ rot_matrix.astype(np.float64)\n\n        local_XY = ((XY - centered_xy).reshape(-1, 2) @ rot_matrix).reshape(\n            128, n_channels, 2\n        )\n\n        for (\n            other_agent_id,\n            other_agent_type,\n            other_xy,\n            other_valids,\n            other_speeds,\n            other_bbox_yaws,\n            other_v_yaws,\n            other_w,\n            other_l,\n            other_predict,\n        ) in zip(\n            Agent_id,\n            Agent_type,\n            local_XY,\n            VALID,\n            Speed,\n            Bbox_yaw,\n            Vyaw,\n            W.flatten(),\n            L.flatten(),\n            tracks_to_predict.flatten(),\n        ):\n            if other_valids.sum() == 0:\n                continue\n\n            GLOBAL_IDX += 1\n            for timestamp, (\n                (x, y),\n                v,\n                other_speed,\n                other_v_yaw,\n                other_bbox_yaw,\n            ) in enumerate(\n                zip(other_xy, other_valids, other_speeds, other_v_yaws, other_bbox_yaws)\n            ):\n                if v == 0:\n                    continue\n                tmp = np.zeros(48)\n            
    tmp[0] = x\n                tmp[1] = y\n                tmp[2] = other_speed\n                tmp[3] = other_v_yaw - ANGLE\n                tmp[4] = other_bbox_yaw - ANGLE\n                tmp[5] = float(other_l)\n                tmp[6] = float(other_w)\n\n                tmp[7:12] = ohe(5, other_agent_type, True)\n\n                tmp[43] = timestamp\n\n                tmp[44] = GLOBAL_IDX\n                tmp[45] = 1 if other_agent_id == agent_id else 0\n                tmp[46] = other_predict\n                tmp[47] = other_agent_id\n\n                RES.append(tmp)\n        local_roadlines_state[:, 44] = local_roadlines_state[:, 44] + GLOBAL_IDX + 1\n        RES = np.concatenate((make_2d(RES), local_roadlines_state), axis=0)\n        GRES.append(RES)\n\n    return GRES\n\n\ndef merge(\n    data, proc_id, validate, out_dir, use_vectorize=False, max_rand_int=10000000000\n):\n    parsed = tf.io.parse_single_example(data, features_description)\n    raster_data = rasterize(\n        parsed[\"state/tracks_to_predict\"].numpy(),\n        parsed[\"state/past/x\"].numpy(),\n        parsed[\"state/past/y\"].numpy(),\n        parsed[\"state/current/x\"].numpy(),\n        parsed[\"state/current/y\"].numpy(),\n        parsed[\"state/current/bbox_yaw\"].numpy(),\n        parsed[\"state/past/bbox_yaw\"].numpy(),\n        parsed[\"state/past/valid\"].numpy(),\n        parsed[\"state/current/valid\"].numpy(),\n        parsed[\"state/type\"].numpy(),\n        parsed[\"roadgraph_samples/xyz\"].numpy(),\n        parsed[\"roadgraph_samples/type\"].numpy(),\n        parsed[\"roadgraph_samples/valid\"].numpy(),\n        parsed[\"roadgraph_samples/id\"].numpy(),\n        parsed[\"state/current/width\"].numpy(),\n        parsed[\"state/current/length\"].numpy(),\n        parsed[\"state/id\"].numpy(),\n        parsed[\"traffic_light_state/current/state\"].numpy(),\n        parsed[\"traffic_light_state/current/id\"].numpy(),\n        
parsed[\"traffic_light_state/current/valid\"].numpy(),\n        parsed[\"state/future/x\"].numpy(),\n        parsed[\"state/future/y\"].numpy(),\n        parsed[\"state/future/valid\"].numpy(),\n        parsed[\"scenario/id\"].numpy()[0].decode(\"utf-8\"),\n        validate=validate,\n    )\n\n    if use_vectorize:\n        vector_data = vectorize(\n            parsed[\"state/past/x\"].numpy(),\n            parsed[\"state/current/x\"].numpy(),\n            parsed[\"state/past/y\"].numpy(),\n            parsed[\"state/current/y\"].numpy(),\n            parsed[\"state/past/valid\"].numpy(),\n            parsed[\"state/current/valid\"].numpy(),\n            parsed[\"state/past/speed\"].numpy(),\n            parsed[\"state/current/speed\"].numpy(),\n            parsed[\"state/past/vel_yaw\"].numpy(),\n            parsed[\"state/current/vel_yaw\"].numpy(),\n            parsed[\"state/past/bbox_yaw\"].numpy(),\n            parsed[\"state/current/bbox_yaw\"].numpy(),\n            parsed[\"state/id\"].numpy(),\n            parsed[\"state/type\"].numpy(),\n            parsed[\"roadgraph_samples/id\"].numpy(),\n            parsed[\"roadgraph_samples/type\"].numpy(),\n            parsed[\"roadgraph_samples/valid\"].numpy(),\n            parsed[\"roadgraph_samples/xyz\"].numpy(),\n            parsed[\"traffic_light_state/current/id\"].numpy(),\n            parsed[\"traffic_light_state/current/state\"].numpy(),\n            parsed[\"traffic_light_state/current/valid\"].numpy(),\n            parsed[\"state/current/width\"].numpy(),\n            parsed[\"state/current/length\"].numpy(),\n            parsed[\"state/tracks_to_predict\"].numpy(),\n            parsed[\"state/future/valid\"].numpy(),\n            validate=validate,\n        )\n\n    for i in range(len(raster_data)):\n        if use_vectorize:\n            raster_data[i][\"vector_data\"] = vector_data[i].astype(np.float16)\n\n        r = np.random.randint(max_rand_int)\n        filename = 
f\"{idx2type[int(raster_data[i]['self_type'])]}_{proc_id}_{str(i).zfill(5)}_{r}.npz\"\n        np.savez_compressed(os.path.join(out_dir, filename), **raster_data[i])\n\n\ndef main():\n    args = parse_arguments()\n    print(args)\n\n    # Create the output directory if it does not exist (including parents).\n    os.makedirs(args.out, exist_ok=True)\n\n    files = os.listdir(args.data)\n    dataset = tf.data.TFRecordDataset(\n        [os.path.join(args.data, f) for f in files], num_parallel_reads=1\n    )\n    if args.n_shards > 1:\n        dataset = dataset.shard(args.n_shards, args.each)\n\n    p = multiprocessing.Pool(args.n_jobs)\n    proc_id = 0\n    res = []\n    for data in tqdm(dataset.as_numpy_iterator()):\n        proc_id += 1\n        res.append(\n            p.apply_async(\n                merge,\n                kwds=dict(\n                    data=data,\n                    proc_id=proc_id,\n                    validate=not args.no_valid,\n                    out_dir=args.out,\n                    use_vectorize=args.use_vectorize,\n                ),\n            )\n        )\n\n    # Wait for all workers; .get() re-raises any exception from a worker.\n    for r in tqdm(res):\n        r.get()\n\n    p.close()\n    p.join()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy\nopencv-python\ntensorflow\ntimm\ntorch\ntqdm\n"
  },
  {
    "path": "submission_proto/motion_submission_pb2.py",
    "content": "# -*- coding: utf-8 -*-\n# Generated by the protocol buffer compiler.  DO NOT EDIT!\n# source: motion_submission.proto\n\"\"\"Generated protocol buffer code.\"\"\"\nfrom google.protobuf import descriptor as _descriptor\nfrom google.protobuf import message as _message\nfrom google.protobuf import reflection as _reflection\nfrom google.protobuf import symbol_database as _symbol_database\n# @@protoc_insertion_point(imports)\n\n_sym_db = _symbol_database.Default()\n\n\n\n\nDESCRIPTOR = _descriptor.FileDescriptor(\n  name='motion_submission.proto',\n  package='waymo.open_dataset',\n  syntax='proto2',\n  serialized_options=None,\n  create_key=_descriptor._internal_create_key,\n  serialized_pb=b'\\n\\x17motion_submission.proto\\x12\\x12waymo.open_dataset\\\"8\\n\\nTrajectory\\x12\\x14\\n\\x08\\x63\\x65nter_x\\x18\\x02 \\x03(\\x02\\x42\\x02\\x10\\x01\\x12\\x14\\n\\x08\\x63\\x65nter_y\\x18\\x03 \\x03(\\x02\\x42\\x02\\x10\\x01\\\"Z\\n\\x10ScoredTrajectory\\x12\\x32\\n\\ntrajectory\\x18\\x01 \\x01(\\x0b\\x32\\x1e.waymo.open_dataset.Trajectory\\x12\\x12\\n\\nconfidence\\x18\\x02 \\x01(\\x02\\\"g\\n\\x16SingleObjectPrediction\\x12\\x11\\n\\tobject_id\\x18\\x01 \\x01(\\x05\\x12:\\n\\x0ctrajectories\\x18\\x02 \\x03(\\x0b\\x32$.waymo.open_dataset.ScoredTrajectory\\\"P\\n\\rPredictionSet\\x12?\\n\\x0bpredictions\\x18\\x01 \\x03(\\x0b\\x32*.waymo.open_dataset.SingleObjectPrediction\\\"Y\\n\\x10ObjectTrajectory\\x12\\x11\\n\\tobject_id\\x18\\x01 \\x01(\\x05\\x12\\x32\\n\\ntrajectory\\x18\\x02 \\x01(\\x0b\\x32\\x1e.waymo.open_dataset.Trajectory\\\"g\\n\\x15ScoredJointTrajectory\\x12:\\n\\x0ctrajectories\\x18\\x02 \\x03(\\x0b\\x32$.waymo.open_dataset.ObjectTrajectory\\x12\\x12\\n\\nconfidence\\x18\\x03 \\x01(\\x02\\\"X\\n\\x0fJointPrediction\\x12\\x45\\n\\x12joint_trajectories\\x18\\x01 \\x03(\\x0b\\x32).waymo.open_dataset.ScoredJointTrajectory\\\"\\xc7\\x01\\n\\x1c\\x43hallengeScenarioPredictions\\x12\\x13\\n\\x0bscenario_id\\x18\\x01 
\\x01(\\t\\x12?\\n\\x12single_predictions\\x18\\x02 \\x01(\\x0b\\x32!.waymo.open_dataset.PredictionSetH\\x00\\x12?\\n\\x10joint_prediction\\x18\\x03 \\x01(\\x0b\\x32#.waymo.open_dataset.JointPredictionH\\x00\\x42\\x10\\n\\x0eprediction_set\\\"\\x96\\x03\\n\\x19MotionChallengeSubmission\\x12\\x14\\n\\x0c\\x61\\x63\\x63ount_name\\x18\\x03 \\x01(\\t\\x12\\x1a\\n\\x12unique_method_name\\x18\\x04 \\x01(\\t\\x12\\x0f\\n\\x07\\x61uthors\\x18\\x05 \\x03(\\t\\x12\\x13\\n\\x0b\\x61\\x66\\x66iliation\\x18\\x06 \\x01(\\t\\x12\\x13\\n\\x0b\\x64\\x65scription\\x18\\x07 \\x01(\\t\\x12\\x13\\n\\x0bmethod_link\\x18\\x08 \\x01(\\t\\x12U\\n\\x0fsubmission_type\\x18\\x02 \\x01(\\x0e\\x32<.waymo.open_dataset.MotionChallengeSubmission.SubmissionType\\x12N\\n\\x14scenario_predictions\\x18\\x01 \\x03(\\x0b\\x32\\x30.waymo.open_dataset.ChallengeScenarioPredictions\\\"P\\n\\x0eSubmissionType\\x12\\x0b\\n\\x07UNKNOWN\\x10\\x00\\x12\\x15\\n\\x11MOTION_PREDICTION\\x10\\x01\\x12\\x1a\\n\\x16INTERACTION_PREDICTION\\x10\\x02'\n)\n\n\n\n_MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE = _descriptor.EnumDescriptor(\n  name='SubmissionType',\n  full_name='waymo.open_dataset.MotionChallengeSubmission.SubmissionType',\n  filename=None,\n  file=DESCRIPTOR,\n  create_key=_descriptor._internal_create_key,\n  values=[\n    _descriptor.EnumValueDescriptor(\n      name='UNKNOWN', index=0, number=0,\n      serialized_options=None,\n      type=None,\n      create_key=_descriptor._internal_create_key),\n    _descriptor.EnumValueDescriptor(\n      name='MOTION_PREDICTION', index=1, number=1,\n      serialized_options=None,\n      type=None,\n      create_key=_descriptor._internal_create_key),\n    _descriptor.EnumValueDescriptor(\n      name='INTERACTION_PREDICTION', index=2, number=2,\n      serialized_options=None,\n      type=None,\n      create_key=_descriptor._internal_create_key),\n  ],\n  containing_type=None,\n  serialized_options=None,\n  serialized_start=1199,\n  
serialized_end=1279,\n)\n_sym_db.RegisterEnumDescriptor(_MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE)\n\n\n_TRAJECTORY = _descriptor.Descriptor(\n  name='Trajectory',\n  full_name='waymo.open_dataset.Trajectory',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='center_x', full_name='waymo.open_dataset.Trajectory.center_x', index=0,\n      number=2, type=2, cpp_type=6, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=b'\\020\\001', file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='center_y', full_name='waymo.open_dataset.Trajectory.center_y', index=1,\n      number=3, type=2, cpp_type=6, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=b'\\020\\001', file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=47,\n  serialized_end=103,\n)\n\n\n_SCOREDTRAJECTORY = _descriptor.Descriptor(\n  name='ScoredTrajectory',\n  full_name='waymo.open_dataset.ScoredTrajectory',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='trajectory', full_name='waymo.open_dataset.ScoredTrajectory.trajectory', index=0,\n      number=1, type=11, cpp_type=10, label=1,\n      has_default_value=False, default_value=None,\n      message_type=None, enum_type=None, containing_type=None,\n      
is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='confidence', full_name='waymo.open_dataset.ScoredTrajectory.confidence', index=1,\n      number=2, type=2, cpp_type=6, label=1,\n      has_default_value=False, default_value=float(0),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=105,\n  serialized_end=195,\n)\n\n\n_SINGLEOBJECTPREDICTION = _descriptor.Descriptor(\n  name='SingleObjectPrediction',\n  full_name='waymo.open_dataset.SingleObjectPrediction',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='object_id', full_name='waymo.open_dataset.SingleObjectPrediction.object_id', index=0,\n      number=1, type=5, cpp_type=1, label=1,\n      has_default_value=False, default_value=0,\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='trajectories', full_name='waymo.open_dataset.SingleObjectPrediction.trajectories', index=1,\n      number=2, type=11, cpp_type=10, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  
nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=197,\n  serialized_end=300,\n)\n\n\n_PREDICTIONSET = _descriptor.Descriptor(\n  name='PredictionSet',\n  full_name='waymo.open_dataset.PredictionSet',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='predictions', full_name='waymo.open_dataset.PredictionSet.predictions', index=0,\n      number=1, type=11, cpp_type=10, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=302,\n  serialized_end=382,\n)\n\n\n_OBJECTTRAJECTORY = _descriptor.Descriptor(\n  name='ObjectTrajectory',\n  full_name='waymo.open_dataset.ObjectTrajectory',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='object_id', full_name='waymo.open_dataset.ObjectTrajectory.object_id', index=0,\n      number=1, type=5, cpp_type=1, label=1,\n      has_default_value=False, default_value=0,\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='trajectory', full_name='waymo.open_dataset.ObjectTrajectory.trajectory', index=1,\n      number=2, type=11, cpp_type=10, label=1,\n      
has_default_value=False, default_value=None,\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=384,\n  serialized_end=473,\n)\n\n\n_SCOREDJOINTTRAJECTORY = _descriptor.Descriptor(\n  name='ScoredJointTrajectory',\n  full_name='waymo.open_dataset.ScoredJointTrajectory',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='trajectories', full_name='waymo.open_dataset.ScoredJointTrajectory.trajectories', index=0,\n      number=2, type=11, cpp_type=10, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='confidence', full_name='waymo.open_dataset.ScoredJointTrajectory.confidence', index=1,\n      number=3, type=2, cpp_type=6, label=1,\n      has_default_value=False, default_value=float(0),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=475,\n  serialized_end=578,\n)\n\n\n_JOINTPREDICTION = _descriptor.Descriptor(\n  name='JointPrediction',\n  
full_name='waymo.open_dataset.JointPrediction',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='joint_trajectories', full_name='waymo.open_dataset.JointPrediction.joint_trajectories', index=0,\n      number=1, type=11, cpp_type=10, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=580,\n  serialized_end=668,\n)\n\n\n_CHALLENGESCENARIOPREDICTIONS = _descriptor.Descriptor(\n  name='ChallengeScenarioPredictions',\n  full_name='waymo.open_dataset.ChallengeScenarioPredictions',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='scenario_id', full_name='waymo.open_dataset.ChallengeScenarioPredictions.scenario_id', index=0,\n      number=1, type=9, cpp_type=9, label=1,\n      has_default_value=False, default_value=b\"\".decode('utf-8'),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='single_predictions', full_name='waymo.open_dataset.ChallengeScenarioPredictions.single_predictions', index=1,\n      number=2, type=11, cpp_type=10, label=1,\n      has_default_value=False, default_value=None,\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      
serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='joint_prediction', full_name='waymo.open_dataset.ChallengeScenarioPredictions.joint_prediction', index=2,\n      number=3, type=11, cpp_type=10, label=1,\n      has_default_value=False, default_value=None,\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n    _descriptor.OneofDescriptor(\n      name='prediction_set', full_name='waymo.open_dataset.ChallengeScenarioPredictions.prediction_set',\n      index=0, containing_type=None,\n      create_key=_descriptor._internal_create_key,\n    fields=[]),\n  ],\n  serialized_start=671,\n  serialized_end=870,\n)\n\n\n_MOTIONCHALLENGESUBMISSION = _descriptor.Descriptor(\n  name='MotionChallengeSubmission',\n  full_name='waymo.open_dataset.MotionChallengeSubmission',\n  filename=None,\n  file=DESCRIPTOR,\n  containing_type=None,\n  create_key=_descriptor._internal_create_key,\n  fields=[\n    _descriptor.FieldDescriptor(\n      name='account_name', full_name='waymo.open_dataset.MotionChallengeSubmission.account_name', index=0,\n      number=3, type=9, cpp_type=9, label=1,\n      has_default_value=False, default_value=b\"\".decode('utf-8'),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='unique_method_name', full_name='waymo.open_dataset.MotionChallengeSubmission.unique_method_name', index=1,\n      number=4, type=9, cpp_type=9, label=1,\n      
has_default_value=False, default_value=b\"\".decode('utf-8'),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='authors', full_name='waymo.open_dataset.MotionChallengeSubmission.authors', index=2,\n      number=5, type=9, cpp_type=9, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='affiliation', full_name='waymo.open_dataset.MotionChallengeSubmission.affiliation', index=3,\n      number=6, type=9, cpp_type=9, label=1,\n      has_default_value=False, default_value=b\"\".decode('utf-8'),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='description', full_name='waymo.open_dataset.MotionChallengeSubmission.description', index=4,\n      number=7, type=9, cpp_type=9, label=1,\n      has_default_value=False, default_value=b\"\".decode('utf-8'),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='method_link', full_name='waymo.open_dataset.MotionChallengeSubmission.method_link', index=5,\n      number=8, type=9, cpp_type=9, label=1,\n      has_default_value=False, default_value=b\"\".decode('utf-8'),\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, 
extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='submission_type', full_name='waymo.open_dataset.MotionChallengeSubmission.submission_type', index=6,\n      number=2, type=14, cpp_type=8, label=1,\n      has_default_value=False, default_value=0,\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n    _descriptor.FieldDescriptor(\n      name='scenario_predictions', full_name='waymo.open_dataset.MotionChallengeSubmission.scenario_predictions', index=7,\n      number=1, type=11, cpp_type=10, label=3,\n      has_default_value=False, default_value=[],\n      message_type=None, enum_type=None, containing_type=None,\n      is_extension=False, extension_scope=None,\n      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),\n  ],\n  extensions=[\n  ],\n  nested_types=[],\n  enum_types=[\n    _MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE,\n  ],\n  serialized_options=None,\n  is_extendable=False,\n  syntax='proto2',\n  extension_ranges=[],\n  oneofs=[\n  ],\n  serialized_start=873,\n  serialized_end=1279,\n)\n\n_SCOREDTRAJECTORY.fields_by_name['trajectory'].message_type = _TRAJECTORY\n_SINGLEOBJECTPREDICTION.fields_by_name['trajectories'].message_type = _SCOREDTRAJECTORY\n_PREDICTIONSET.fields_by_name['predictions'].message_type = _SINGLEOBJECTPREDICTION\n_OBJECTTRAJECTORY.fields_by_name['trajectory'].message_type = _TRAJECTORY\n_SCOREDJOINTTRAJECTORY.fields_by_name['trajectories'].message_type = _OBJECTTRAJECTORY\n_JOINTPREDICTION.fields_by_name['joint_trajectories'].message_type = _SCOREDJOINTTRAJECTORY\n_CHALLENGESCENARIOPREDICTIONS.fields_by_name['single_predictions'].message_type = 
_PREDICTIONSET\n_CHALLENGESCENARIOPREDICTIONS.fields_by_name['joint_prediction'].message_type = _JOINTPREDICTION\n_CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set'].fields.append(\n  _CHALLENGESCENARIOPREDICTIONS.fields_by_name['single_predictions'])\n_CHALLENGESCENARIOPREDICTIONS.fields_by_name['single_predictions'].containing_oneof = _CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set']\n_CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set'].fields.append(\n  _CHALLENGESCENARIOPREDICTIONS.fields_by_name['joint_prediction'])\n_CHALLENGESCENARIOPREDICTIONS.fields_by_name['joint_prediction'].containing_oneof = _CHALLENGESCENARIOPREDICTIONS.oneofs_by_name['prediction_set']\n_MOTIONCHALLENGESUBMISSION.fields_by_name['submission_type'].enum_type = _MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE\n_MOTIONCHALLENGESUBMISSION.fields_by_name['scenario_predictions'].message_type = _CHALLENGESCENARIOPREDICTIONS\n_MOTIONCHALLENGESUBMISSION_SUBMISSIONTYPE.containing_type = _MOTIONCHALLENGESUBMISSION\nDESCRIPTOR.message_types_by_name['Trajectory'] = _TRAJECTORY\nDESCRIPTOR.message_types_by_name['ScoredTrajectory'] = _SCOREDTRAJECTORY\nDESCRIPTOR.message_types_by_name['SingleObjectPrediction'] = _SINGLEOBJECTPREDICTION\nDESCRIPTOR.message_types_by_name['PredictionSet'] = _PREDICTIONSET\nDESCRIPTOR.message_types_by_name['ObjectTrajectory'] = _OBJECTTRAJECTORY\nDESCRIPTOR.message_types_by_name['ScoredJointTrajectory'] = _SCOREDJOINTTRAJECTORY\nDESCRIPTOR.message_types_by_name['JointPrediction'] = _JOINTPREDICTION\nDESCRIPTOR.message_types_by_name['ChallengeScenarioPredictions'] = _CHALLENGESCENARIOPREDICTIONS\nDESCRIPTOR.message_types_by_name['MotionChallengeSubmission'] = _MOTIONCHALLENGESUBMISSION\n_sym_db.RegisterFileDescriptor(DESCRIPTOR)\n\nTrajectory = _reflection.GeneratedProtocolMessageType('Trajectory', (_message.Message,), {\n  'DESCRIPTOR' : _TRAJECTORY,\n  '__module__' : 'motion_submission_pb2'\n  # 
@@protoc_insertion_point(class_scope:waymo.open_dataset.Trajectory)\n  })\n_sym_db.RegisterMessage(Trajectory)\n\nScoredTrajectory = _reflection.GeneratedProtocolMessageType('ScoredTrajectory', (_message.Message,), {\n  'DESCRIPTOR' : _SCOREDTRAJECTORY,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ScoredTrajectory)\n  })\n_sym_db.RegisterMessage(ScoredTrajectory)\n\nSingleObjectPrediction = _reflection.GeneratedProtocolMessageType('SingleObjectPrediction', (_message.Message,), {\n  'DESCRIPTOR' : _SINGLEOBJECTPREDICTION,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.SingleObjectPrediction)\n  })\n_sym_db.RegisterMessage(SingleObjectPrediction)\n\nPredictionSet = _reflection.GeneratedProtocolMessageType('PredictionSet', (_message.Message,), {\n  'DESCRIPTOR' : _PREDICTIONSET,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.PredictionSet)\n  })\n_sym_db.RegisterMessage(PredictionSet)\n\nObjectTrajectory = _reflection.GeneratedProtocolMessageType('ObjectTrajectory', (_message.Message,), {\n  'DESCRIPTOR' : _OBJECTTRAJECTORY,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ObjectTrajectory)\n  })\n_sym_db.RegisterMessage(ObjectTrajectory)\n\nScoredJointTrajectory = _reflection.GeneratedProtocolMessageType('ScoredJointTrajectory', (_message.Message,), {\n  'DESCRIPTOR' : _SCOREDJOINTTRAJECTORY,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ScoredJointTrajectory)\n  })\n_sym_db.RegisterMessage(ScoredJointTrajectory)\n\nJointPrediction = _reflection.GeneratedProtocolMessageType('JointPrediction', (_message.Message,), {\n  'DESCRIPTOR' : _JOINTPREDICTION,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.JointPrediction)\n  
})\n_sym_db.RegisterMessage(JointPrediction)\n\nChallengeScenarioPredictions = _reflection.GeneratedProtocolMessageType('ChallengeScenarioPredictions', (_message.Message,), {\n  'DESCRIPTOR' : _CHALLENGESCENARIOPREDICTIONS,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.ChallengeScenarioPredictions)\n  })\n_sym_db.RegisterMessage(ChallengeScenarioPredictions)\n\nMotionChallengeSubmission = _reflection.GeneratedProtocolMessageType('MotionChallengeSubmission', (_message.Message,), {\n  'DESCRIPTOR' : _MOTIONCHALLENGESUBMISSION,\n  '__module__' : 'motion_submission_pb2'\n  # @@protoc_insertion_point(class_scope:waymo.open_dataset.MotionChallengeSubmission)\n  })\n_sym_db.RegisterMessage(MotionChallengeSubmission)\n\n\n_TRAJECTORY.fields_by_name['center_x']._options = None\n_TRAJECTORY.fields_by_name['center_y']._options = None\n# @@protoc_insertion_point(module_scope)\n"
  },
  {
    "path": "submit.py",
    "content": "import argparse\n\n# chage this if you have problem\nimport sys\nsys.path.insert(1, \"~/.local/lib/python3.6/site-packages\")\n\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\n\nfrom submission_proto import motion_submission_pb2\nfrom train import WaymoLoader, Model\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--test-data\", type=str, required=True, help=\"Path to rasterized data\"\n    )\n    parser.add_argument(\n        \"--model-path\", type=str, required=True, help=\"Path to CNN model\"\n    )\n    parser.add_argument(\n        \"--time-limit\", type=int, required=False, default=80, help=\"Number time steps\"\n    )\n    parser.add_argument(\n        \"--save\", type=str, required=True, help=\"Path to save predictions\"\n    )\n    parser.add_argument(\n        \"--model-name\", type=str, required=False, help=\"Model name\"\n    )\n\n    parser.add_argument(\"--account-name\", required=False, default=\"\")\n    parser.add_argument(\"--authors\", required=False, default=\"\")\n    parser.add_argument(\"--method-name\", required=False, default=\"SimpleCNNOnRaster\")\n\n    parser.add_argument(\"--batch-size\", type=int, required=False, default=128)\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    print(args)\n\n    if args.model_path.endswith(\"pth\"):\n        model = Model(args.model_name)\n        model.load_state_dict(torch.load(args.model_path)[\"model_state_dict\"])\n    else:\n        model = torch.jit.load(args.model_path)\n\n    model.cuda()\n    model.eval()\n\n    dataset = WaymoLoader(args.test_data, is_test=True)\n    loader = DataLoader(\n        dataset, batch_size=args.batch_size, num_workers=min(args.batch_size, 16)\n    )\n\n    RES = {}\n    with torch.no_grad():\n        for x, center, yaw, agent_id, scenario_id, _, _ in tqdm(loader):\n            x = x.cuda()\n          
  confidences_logits, logits = model(x)\n            confidences = torch.softmax(confidences_logits, dim=1)\n\n            logits = logits.cpu().numpy()\n            confidences = confidences.cpu().numpy()\n            agent_id = agent_id.cpu().numpy()\n            center = center.cpu().numpy()\n            yaw = yaw.cpu().numpy()\n            for p, conf, aid, sid, c, y in zip(\n                logits, confidences, agent_id, scenario_id, center, yaw\n            ):\n                if sid not in RES:\n                    RES[sid] = []\n\n                RES[sid].append(\n                    {\"aid\": aid, \"conf\": conf, \"pred\": p, \"yaw\": -y, \"center\": c}\n                )\n\n    motion_challenge_submission = motion_submission_pb2.MotionChallengeSubmission()\n    motion_challenge_submission.account_name = args.account_name\n    motion_challenge_submission.authors.extend(args.authors.split(\",\"))\n    motion_challenge_submission.submission_type = (\n        motion_submission_pb2.MotionChallengeSubmission.SubmissionType.MOTION_PREDICTION\n    )\n    motion_challenge_submission.unique_method_name = args.method_name\n\n    selector = np.arange(4, args.time_limit + 1, 5)\n    for scenario_id, data in tqdm(RES.items()):\n        scenario_predictions = motion_challenge_submission.scenario_predictions.add()\n        scenario_predictions.scenario_id = scenario_id\n        prediction_set = scenario_predictions.single_predictions\n\n        for d in data:\n            predictions = prediction_set.predictions.add()\n            predictions.object_id = int(d[\"aid\"])\n\n            y = d[\"yaw\"]\n            rot_matrix = np.array([\n                [np.cos(y), -np.sin(y)],\n                [np.sin(y), np.cos(y)],\n            ])\n\n            for i in np.argsort(-d[\"conf\"]):\n                scored_trajectory = predictions.trajectories.add()\n                scored_trajectory.confidence = d[\"conf\"][i]\n\n                trajectory = 
scored_trajectory.trajectory\n\n                p = d[\"pred\"][i][selector] @ rot_matrix + d[\"center\"]\n\n                trajectory.center_x.extend(p[:, 0])\n                trajectory.center_y.extend(p[:, 1])\n\n    with open(args.save, \"wb\") as f:\n        f.write(motion_challenge_submission.SerializeToString())\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "train.py",
    "content": "import argparse\nimport os\n\nimport numpy as np\nimport timm\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, Dataset\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm import tqdm\n\nIMG_RES = 224\nIN_CHANNELS = 25\nTL = 80\nN_TRAJS = 6\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--train-data\", type=str, required=True, help=\"Path to rasterized data\"\n    )\n    parser.add_argument(\n        \"--dev-data\", type=str, required=True, help=\"Path to rasterized data\"\n    )\n    parser.add_argument(\n        \"--img-res\",\n        type=int,\n        required=False,\n        default=IMG_RES,\n        help=\"Input images resolution\",\n    )\n    parser.add_argument(\n        \"--in-channels\",\n        type=int,\n        required=False,\n        default=IN_CHANNELS,\n        help=\"Input raster channels\",\n    )\n    parser.add_argument(\n        \"--time-limit\",\n        type=int,\n        required=False,\n        default=TL,\n        help=\"Number time step to predict\",\n    )\n    parser.add_argument(\n        \"--n-traj\",\n        type=int,\n        required=False,\n        default=N_TRAJS,\n        help=\"Number of trajectories to predict\",\n    )\n    parser.add_argument(\n        \"--save\", type=str, required=True, help=\"Path to save model and logs\"\n    )\n\n    parser.add_argument(\n        \"--model\", type=str, required=False, default=\"xception71\", help=\"CNN model name\"\n    )\n    parser.add_argument(\"--lr\", type=float, required=False, default=1e-3)\n    parser.add_argument(\"--batch-size\", type=int, required=False, default=48)\n    parser.add_argument(\"--n-epochs\", type=int, required=False, default=60)\n\n    parser.add_argument(\"--valid-limit\", type=int, required=False, default=24 * 100)\n    parser.add_argument(\n        \"--n-monitor-train\",\n        type=int,\n        required=False,\n        default=10,\n        
help=\"Log training metrics every `n-monitor-train` steps\",\n    )\n    parser.add_argument(\n        \"--n-monitor-validate\",\n        type=int,\n        required=False,\n        default=1000,\n        help=\"Validate the model every `n-monitor-validate` steps\",\n    )\n\n    args = parser.parse_args()\n\n    return args\n\n\nclass Model(nn.Module):\n    def __init__(\n        self, model_name, in_channels=IN_CHANNELS, time_limit=TL, n_traj=N_TRAJS\n    ):\n        super().__init__()\n\n        self.n_traj = n_traj\n        self.time_limit = time_limit\n        self.model = timm.create_model(\n            model_name,\n            pretrained=True,\n            in_chans=in_channels,\n            num_classes=self.n_traj * 2 * self.time_limit + self.n_traj,\n        )\n\n    def forward(self, x):\n        outputs = self.model(x)\n\n        confidences_logits, logits = (\n            outputs[:, : self.n_traj],\n            outputs[:, self.n_traj :],\n        )\n        logits = logits.view(-1, self.n_traj, self.time_limit, 2)\n\n        return confidences_logits, logits\n\n\ndef pytorch_neg_multi_log_likelihood_batch(gt, logits, confidences, avails):\n    \"\"\"\n    Compute the negative log-likelihood for the multi-modal scenario.\n    Args:\n        gt (Tensor): array of shape (bs)x(time)x(2D coords)\n        logits (Tensor): array of shape (bs)x(modes)x(time)x(2D coords)\n        confidences (Tensor): array of shape (bs)x(modes) with a confidence for each mode in each sample\n        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep\n    Returns:\n        Tensor: negative log-likelihood averaged over the batch, a single float\n    \"\"\"\n\n    # convert to (batch_size, num_modes, future_len, num_coords)\n    gt = torch.unsqueeze(gt, 1)  # add modes\n    avails = avails[:, None, :, None]  # add modes and coords\n\n    # error (batch_size, num_modes, future_len)\n    error = torch.sum(\n        ((gt - logits) * avails) ** 2, dim=-1\n    )  # reduce coords and use availability\n\n    # log_softmax is numerically stable, so zero-confidence modes need no special\n    # handling (np.errstate would only silence numpy warnings, not torch ones)\n    # error (batch_size, num_modes)\n    error = nn.functional.log_softmax(confidences, dim=1) - 0.5 * torch.sum(\n        error, dim=-1\n    )  # reduce time\n\n    # error (batch_size, 1)\n    error = -torch.logsumexp(error, dim=-1, keepdim=True)\n\n    return torch.mean(error)\n\n\nclass WaymoLoader(Dataset):\n    def __init__(self, directory, limit=0, return_vector=False, is_test=False):\n        files = os.listdir(directory)\n        # sort before limiting so a positive `limit` selects a deterministic subset\n        self.files = sorted(\n            os.path.join(directory, f) for f in files if f.endswith(\".npz\")\n        )\n\n        if limit > 0:\n            self.files = self.files[:limit]\n\n        self.return_vector = return_vector\n        self.is_test = is_test\n\n    def __len__(self):\n        return len(self.files)\n\n    def __getitem__(self, idx):\n        filename = self.files[idx]\n        data = np.load(filename, allow_pickle=True)\n\n        raster = data[\"raster\"].astype(\"float32\")\n        raster = raster.transpose(2, 1, 0) / 255\n\n        if self.is_test:\n            center = data[\"shift\"]\n            yaw = data[\"yaw\"]\n            agent_id = data[\"object_id\"]\n            scenario_id = data[\"scenario_id\"]\n\n            return (\n                raster,\n                center,\n                yaw,\n                agent_id,\n                str(scenario_id),\n                data[\"_gt_marginal\"],\n                data[\"gt_marginal\"],\n            )\n\n        trajectory = data[\"gt_marginal\"]\n\n        is_available = data[\"future_val_marginal\"]\n\n        if self.return_vector:\n            return raster, trajectory, is_available, data[\"vector_data\"]\n\n        return raster, trajectory, is_available\n\n\ndef main():\n    args = parse_args()\n\n    summary_writer = SummaryWriter(os.path.join(args.save, \"logs\"))\n\n    train_path = args.train_data\n    dev_path = args.dev_data\n    path_to_save = args.save\n    os.makedirs(path_to_save, exist_ok=True)\n\n    dataset = WaymoLoader(train_path)\n\n    batch_size = args.batch_size\n    num_workers = min(16, batch_size)\n    dataloader = DataLoader(\n        dataset,\n        batch_size=batch_size,\n        shuffle=True,\n        num_workers=num_workers,\n        pin_memory=False,\n        persistent_workers=True,\n    )\n\n    val_dataset = WaymoLoader(dev_path, limit=args.valid_limit)\n    val_dataloader = DataLoader(\n        val_dataset,\n        batch_size=batch_size * 2,\n        shuffle=False,\n        num_workers=num_workers,\n        pin_memory=False,\n        persistent_workers=True,\n    )\n\n    model_name = args.model\n    time_limit = args.time_limit\n    n_traj = args.n_traj\n    model = Model(\n        model_name, in_channels=args.in_channels, time_limit=time_limit, n_traj=n_traj\n    )\n    model.cuda()\n\n    lr = args.lr\n    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n\n    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n        optimizer,\n        T_0=2 * len(dataloader),\n        T_mult=1,\n        eta_min=max(1e-2 * lr, 1e-6),\n        last_epoch=-1,\n    )\n\n    start_iter = 0\n    best_loss = float(\"inf\")\n    glosses = []\n\n    tr_it = iter(dataloader)\n    n_epochs = args.n_epochs\n    progress_bar = tqdm(range(start_iter, len(dataloader) * n_epochs))\n\n    def saver(name):\n        # checkpoint the full training state so interrupted runs can be resumed\n        torch.save(\n            {\n                \"score\": best_loss,\n                \"iteration\": iteration,\n                \"model_state_dict\": model.state_dict(),\n                \"optimizer_state_dict\": optimizer.state_dict(),\n                \"scheduler_state_dict\": scheduler.state_dict(),\n                \"loss\": loss.item(),\n            },\n            os.path.join(path_to_save, name),\n        )\n\n    for iteration in progress_bar:\n        model.train()\n        try:\n            x, y, is_available = next(tr_it)\n        except StopIteration:\n            tr_it = iter(dataloader)\n            x, y, is_available = next(tr_it)\n\n        x, y, is_available = map(lambda t: t.cuda(), (x, y, is_available))\n\n        optimizer.zero_grad()\n\n        confidences_logits, logits = model(x)\n\n        loss = pytorch_neg_multi_log_likelihood_batch(\n            y, logits, confidences_logits, is_available\n        )\n        loss.backward()\n        optimizer.step()\n        scheduler.step()\n\n        glosses.append(loss.item())\n        if (iteration + 1) % args.n_monitor_train == 0:\n            progress_bar.set_description(\n                f\"loss: {loss.item():.3}\"\n                f\" avg: {np.mean(glosses[-100:]):.2}\"\n                f\" {scheduler.get_last_lr()[-1]:.3}\"\n            )\n            summary_writer.add_scalar(\"train/loss\", loss.item(), iteration)\n            summary_writer.add_scalar(\"lr\", scheduler.get_last_lr()[-1], iteration)\n\n        if (iteration + 1) % args.n_monitor_validate == 0:\n            optimizer.zero_grad()\n            model.eval()\n            with torch.no_grad():\n                val_losses = []\n                for x, y, is_available in val_dataloader:\n                    x, y, is_available = map(lambda t: t.cuda(), (x, y, is_available))\n\n                    confidences_logits, logits = model(x)\n                    loss = pytorch_neg_multi_log_likelihood_batch(\n                        y, logits, confidences_logits, is_available\n                    )\n                    val_losses.append(loss.item())\n\n                summary_writer.add_scalar(\"dev/loss\", np.mean(val_losses), iteration)\n\n            saver(\"model_last.pth\")\n\n            mean_val_loss = np.mean(val_losses)\n            if mean_val_loss < best_loss:\n                best_loss = mean_val_loss\n                saver(\"model_best.pth\")\n\n                model.eval()\n                with torch.no_grad():\n                    traced_model = torch.jit.trace(\n                        model,\n                        torch.rand(\n                            1, args.in_channels, args.img_res, args.img_res\n                        ).cuda(),\n                    )\n\n                traced_model.save(os.path.join(path_to_save, \"model_best.pt\"))\n                del traced_model\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "visualize.py",
    "content": "import argparse\nimport os\n\nimport numpy as np\nimport torch\nfrom matplotlib import pyplot as plt\nfrom matplotlib.pyplot import figure\nfrom torch.utils.data import DataLoader\n\nfrom train import WaymoLoader, pytorch_neg_multi_log_likelihood_batch\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--data\", type=str, required=True)\n    parser.add_argument(\"--save\", type=str, required=True)\n    parser.add_argument(\"--n-samples\", type=int, required=False, default=50)\n    parser.add_argument(\"--use-top1\", action=\"store_true\")\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef main():\n    args = parse_args()\n    if not os.path.exists(args.save):\n        os.mkdir(args.save)\n\n    model = torch.jit.load(args.model).cuda().eval()\n    loader = DataLoader(\n        WaymoLoader(args.data, return_vector=True),\n        batch_size=1,\n        num_workers=1,\n        shuffle=False,\n    )\n\n    iii = 0\n    with torch.no_grad():\n        for x, y, is_available, vector_data in loader:\n            x, y, is_available = map(lambda x: x.cuda(), (x, y, is_available))\n\n            confidences_logits, logits = model(x)\n\n            argmax = confidences_logits.argmax()\n            if args.use_top1:\n                confidences_logits = confidences_logits[:, argmax].unsqueeze(1)\n                logits = logits[:, argmax].unsqueeze(1)\n\n            loss = pytorch_neg_multi_log_likelihood_batch(\n                y, logits, confidences_logits, is_available\n            )\n            confidences = torch.softmax(confidences_logits, dim=1)\n            V = vector_data[0]\n\n            X, idx = V[:, :44], V[:, 44].flatten()\n\n            figure(figsize=(15, 15), dpi=80)\n            for i in np.unique(idx):\n                _X = X[idx == i]\n                if _X[:, 5:12].sum() > 0:\n                    plt.plot(_X[:, 0], _X[:, 1], 
linewidth=4, color=\"red\")\n                else:\n                    plt.plot(_X[:, 0], _X[:, 1], color=\"black\")\n\n            # set the axes limits once, after all lane polylines are drawn\n            plt.xlim([-224 // 4, 224 // 4])\n            plt.ylim([-224 // 4, 224 // 4])\n\n            logits = logits.squeeze(0).cpu().numpy()\n            y = y.squeeze(0).cpu().numpy()\n            is_available = is_available.squeeze(0).long().cpu().numpy()\n            confidences = confidences.squeeze(0).cpu().numpy()\n            plt.plot(\n                y[is_available > 0][::10, 0],\n                y[is_available > 0][::10, 1],\n                \"-o\",\n                label=\"gt\",\n            )\n\n            plt.plot(\n                logits[confidences.argmax()][is_available > 0][::10, 0],\n                logits[confidences.argmax()][is_available > 0][::10, 1],\n                \"-o\",\n                label=\"pred top 1\",\n            )\n            if not args.use_top1:\n                for traj_id in range(len(logits)):\n                    if traj_id == argmax:\n                        continue\n\n                    alpha = confidences[traj_id].item()\n                    plt.plot(\n                        logits[traj_id][is_available > 0][::10, 0],\n                        logits[traj_id][is_available > 0][::10, 1],\n                        \"-o\",\n                        label=f\"pred {traj_id} {alpha:.3f}\",\n                        alpha=alpha,\n                    )\n\n            plt.title(f\"NLL: {loss.item():.3f}\")\n            plt.legend()\n            plt.savefig(\n                os.path.join(args.save, f\"{iii:0>2}_{loss.item():.3f}.png\")\n            )\n            plt.close()\n\n            iii += 1\n            if iii == args.n_samples:\n                break\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]