Repository: salesforce/ALPRO
Branch: main
Commit: d21173f55a73
Files: 62
Total size: 462.2 KB

Directory structure:
gitextract_egeq4lt4/

├── .gitignore
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE
├── README.md
├── SECURITY.md
├── config_release/
│   ├── base_model.json
│   ├── didemo_ret.json
│   ├── msrvtt_qa.json
│   ├── msrvtt_ret.json
│   ├── msvd_qa.json
│   ├── pretrain_alpro.json
│   ├── pretrain_prompter.json
│   ├── timesformer_divst_8x32_224_k600.json
│   └── timesformer_divst_8x32_224_k600_gc.json
├── env/
│   ├── install_pkg.sh
│   └── requirements.txt
├── run_scripts/
│   ├── clear_cuda_cache.sh
│   ├── ft_didemo_ret.sh
│   ├── ft_msrvtt_qa.sh
│   ├── ft_msrvtt_ret.sh
│   ├── ft_msvd_qa.sh
│   ├── inf_didemo_ret.sh
│   ├── inf_msrvtt_qa.sh
│   ├── inf_msrvtt_ret.sh
│   ├── inf_msvd_qa.sh
│   ├── pt_alpro.sh
│   └── pt_prompter.sh
└── src/
    ├── __init__.py
    ├── configs/
    │   └── config.py
    ├── datasets/
    │   ├── data_utils.py
    │   ├── dataloader.py
    │   ├── dataset_base.py
    │   ├── dataset_pretrain_sparse.py
    │   ├── dataset_video_qa.py
    │   ├── dataset_video_retrieval.py
    │   └── randaugment.py
    ├── modeling/
    │   ├── alpro_models.py
    │   ├── timesformer/
    │   │   ├── __init__.py
    │   │   ├── conv2d_same.py
    │   │   ├── features.py
    │   │   ├── helpers.py
    │   │   ├── linear.py
    │   │   ├── operators.py
    │   │   ├── vit.py
    │   │   └── vit_utils.py
    │   ├── transformers.py
    │   └── xbert.py
    ├── optimization/
    │   ├── adamw.py
    │   ├── sched.py
    │   └── utils.py
    ├── pretrain/
    │   ├── run_pretrain_contrastive_only.py
    │   └── run_pretrain_sparse.py
    ├── tasks/
    │   ├── run_video_qa.py
    │   └── run_video_retrieval.py
    └── utils/
        ├── basic_utils.py
        ├── distributed.py
        ├── grad_ckpt.py
        ├── load_save.py
        ├── logger.py
        └── misc.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.vscode 

# script
tmp_all/script/

# Philly-related #
pt/
.ptconfig



# Project-related   #
*/*results*/
*results*/
tmp*/
cache/*
*/cache*/
tmp*.py
*pickle

# compiled files #
*.pyc

# Packages #
############
# it's better to unpack these files and commit the raw source
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip

# Logs and databases #
######################
*.log
*.sql
*.sqlite
.ipynb_checkpoints/
*.swp
*.vscode/
*.idea/

# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# project-specific
img
txt
ext
data
output
src/configs_local


================================================
FILE: CODEOWNERS
================================================
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Salesforce Open Source Community Code of Conduct

## About the Code of Conduct

Equality is a core value at Salesforce. We believe a diverse and inclusive
community fosters innovation and creativity, and are committed to building a
culture where everyone feels included.

Salesforce open-source projects are committed to providing a friendly, safe, and
welcoming environment for all, regardless of gender identity and expression,
sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 
race, age, religion, level of experience, education, socioeconomic status, or 
other similar personal characteristics.

The goal of this code of conduct is to specify a baseline standard of behavior so
that people with different social values and communication styles can work
together effectively, productively, and respectfully in our open source community.
It also establishes a mechanism for reporting issues and resolving conflicts.

All questions and reports of abusive, harassing, or otherwise unacceptable behavior
in a Salesforce open-source project may be reported by contacting the Salesforce
Open Source Conduct Committee at ossconduct@salesforce.com.

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of gender 
identity and expression, sexual orientation, disability, physical appearance, 
body size, ethnicity, nationality, race, age, religion, level of experience, education, 
socioeconomic status, or other similar personal characteristics.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy toward other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Personal attacks, insulting/derogatory comments, or trolling
* Public or private harassment
* Publishing, or threatening to publish, others' private information—such as
a physical or electronic address—without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
* Advocating for or encouraging any of the above behaviors

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned with this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project email
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the Salesforce Open Source Conduct Committee 
at ossconduct@salesforce.com. All complaints will be reviewed and investigated 
and will result in a response that is deemed necessary and appropriate to the 
circumstances. The committee is obligated to maintain confidentiality with 
regard to the reporter of an incident. Further details of specific enforcement 
policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership and the Salesforce Open Source Conduct 
Committee.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 
It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].

This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].

[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
[golang-coc]: https://golang.org/conduct
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/



================================================
FILE: CONTRIBUTING-ARCHIVED.md
================================================
# ARCHIVED

This project is `Archived` and is no longer actively maintained;
We are not accepting contributions or Pull Requests.



================================================
FILE: LICENSE
================================================
Copyright (c) 2021, Salesforce.com, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

================================================
FILE: README.md
================================================
# ALPRO (CVPR'22)

## ALPRO is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS), a one-stop library for language-vision intelligence!

## Align and Prompt: Video-and-Language Pre-training with Entity Prompts [[Paper](https://arxiv.org/abs/2112.09583)]

[Dongxu Li](https://www.linkedin.com/in/dongxu-li-a8a035110/), [Junnan Li](https://sites.google.com/site/junnanlics), [Hongdong Li](http://users.cecs.anu.edu.au/~hongdong/), [Juan Carlos Niebles](http://www.niebles.net/), [Steven C.H. Hoi](https://sites.google.com/view/stevenhoi/home)

<img src="pics/teaser.jpg" width="500">

Official PyTorch code for ALPRO. This repository supports pre-training as well as finetuning on 
- Text-Video Retrieval on MSRVTT and DiDeMo.
- Video Question Answering on MSRVTT and MSVD.

## Requirements
Our implementation is tested on Ubuntu 20.04.1 with NVIDIA A100 GPUs. Other platforms and hardware may work, but this is not guaranteed. To install the required packages:

```bash
cd env && bash install_pkg.sh
```
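
To confirm that the pinned versions in `env/requirements.txt` (for example `transformers==4.11.3`) actually ended up installed, a small check like the one below may help. It is a sketch, not part of the repository: the helper names `pinned_requirements` and `check_pins` are hypothetical, it only understands plain `name==version` lines, it assumes the installed distribution names match the names in the file, and it needs Python 3.8+ for `importlib.metadata`.

```python
from importlib import metadata


def pinned_requirements(lines):
    """Return {name: version} for lines of the form 'name==version'.

    Comments, blank lines and unpinned entries (e.g. 'timm', 'pycocotools>=2.0.1') are skipped.
    """
    pins = {}
    for line in lines:
        line = line.split("#", 1)[0].strip()  # drop trailing comments
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def check_pins(pins):
    """Return {name: (wanted, installed)} for every pin that is missing or at another version."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


def check_requirements_file(path="env/requirements.txt"):
    """Hypothetical convenience wrapper; run from ALPRO/ after install_pkg.sh."""
    with open(path) as f:
        return check_pins(pinned_requirements(f))
```

An empty dict from `check_requirements_file()` means every pinned package is installed at the pinned version.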

## Data Preparation 
1. Download Annotations and Pre-trained Checkpoints
    - [Text annotations](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/data.zip)
    - [Checkpoints of pre-trained model and finetuned model](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/output.zip)
    - [External resources](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/ext.zip)
    - unzip `data.zip`, `output.zip`, `ext.zip` under `ALPRO/`.
 
2. Download raw videos of downstream datasets.
   - MSRVTT:
     - download `train_val_videos.zip` and `test_videos.zip` from e.g. [here](https://www.mediafire.com/folder/h14iarbs62e7p/shared).
     - check md5sum:

       ```bash
       51f2394d279cf84f1642defd9a651e6f  train_val_videos.zip
       0af68454cec9d586e92805739f3911d0  test_videos.zip
       ```
     - unzip all the videos into `data/msrvtt_ret/videos` (10k in total).
     - create the following soft link (run from `ALPRO/`; a relative link target is resolved from the directory containing the link, i.e. `data/msrvtt_qa/`):

       ```bash
       ln -s ../msrvtt_ret/videos data/msrvtt_qa/videos
       ```
   - MSVD:
     - download from the official release:

       ```bash
       wget -nc https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar
       ```
     - check md5sum:

       ```bash
       9bdb20fcf14d59524a6febca9f6a8d89  YouTubeClips.tar
       ```
     - extract all the videos to `data/msvd_qa/videos` (1,970 videos in total).

       ```bash
       mkdir -p data/msvd_qa/videos/
       tar xvf YouTubeClips.tar -C data/msvd_qa/videos --strip-components=1
       ```
   - DiDeMo:
     - follow the [instructions](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md) and download from the official release [here](https://drive.google.com/drive/u/1/folders/1_oyJ5rQiZboipbMl6tkhY8v0s9zDkvJc).
     - unzip all the videos into `data/didemo_ret/videos`.
     - Note that a couple of videos may be missing; see [here](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md#getting-the-videos) for how to download them. Since they account for only a small portion of the training set, it is generally safe to ignore them.
     - convert all the DiDeMo videos into `*.mp4` format using e.g. [`ffmpeg`](https://askubuntu.com/questions/396883/how-to-simply-convert-video-files-i-e-mkv-to-mp4).
     - We obtained 10,463 videos following these steps (with one video `77807177@N00_5753455690_1e04ccb364` missing).

3. The directory is expected to have the structure below:
      ```bash
      .
      |-config_release  # configuration files
      |-data  # text annotations and raw videos
      |---didemo_ret
      |-----txt
      |-----videos
      |---msrvtt_qa/...
      |---msrvtt_ret/...
      |---msvd_qa/...
      |-env  # scripts to install packages
      |-ext  # external resources, e.g. bert tokenizer
      |-output  # checkpoints for pre-trained/finetuned models
      |---downstreams
      |-----didemo_ret
      |-------public
      |---------ckpt # official finetuned checkpoints
      |---------log # inference log
      |---------results_test
      |-----------step_best_1_mean
      |-----msrvtt_qa/...
      |-----msrvtt_ret/...
      |-----msvd_qa/...
      |-run_scripts  # bash scripts to launch experiments
      |-src  # source code
      ```
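
The md5 checks in step 2 and the layout in step 3 can also be verified from Python. This is a minimal sketch, not part of the repository: the helper names are hypothetical, the checksums are the ones listed in step 2, and only a subset of the directories from the tree in step 3 is checked.

```python
import hashlib
from pathlib import Path

# Checksums from step 2 (MSRVTT and MSVD archives).
EXPECTED_MD5 = {
    "train_val_videos.zip": "51f2394d279cf84f1642defd9a651e6f",
    "test_videos.zip": "0af68454cec9d586e92805739f3911d0",
    "YouTubeClips.tar": "9bdb20fcf14d59524a6febca9f6a8d89",
}

# Directories from the expected layout, relative to ALPRO/.
# is_dir() follows symlinks, so a dangling data/msrvtt_qa/videos link is reported as missing.
EXPECTED_DIRS = [
    "config_release", "env", "ext", "run_scripts", "src",
    "data/didemo_ret/txt", "data/didemo_ret/videos",
    "data/msrvtt_ret/videos", "data/msrvtt_qa/videos", "data/msvd_qa/videos",
    "output/downstreams",
]


def md5sum(path, chunk_size=1 << 20):
    """Hex md5 digest of a file, read in 1 MiB chunks so large archives need little memory."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_archives(download_dir="."):
    """Map each expected archive to True if it exists in download_dir and its md5 matches."""
    status = {}
    for name, expected in EXPECTED_MD5.items():
        path = Path(download_dir) / name
        status[name] = path.is_file() and md5sum(path) == expected
    return status


def missing_dirs(root="."):
    """Return the expected directories that do not exist under root."""
    return [d for d in EXPECTED_DIRS if not (Path(root) / d).is_dir()]
```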

## Inference with Official Checkpoints

  ```bash
  cd run_scripts
  bash inf_msrvtt_ret.sh
  # {'text2video': {'r1': 33.9, 'r5': 60.7, 'r10': 73.2, 'medianR': 3.0, 'meanR': 27.404}}
  bash inf_didemo_ret.sh
  # {'text2video': {'r1': 35.9, 'r5': 67.5, 'r10': 78.8, 'medianR': 3.0, 'meanR': 19.125}}
  bash inf_msrvtt_qa.sh
  # {'ratios': {'what_ratio': [68.48, 49872], 'who_ratio': [27.99, 20385], 'how_ratio': [2.25, 1640], 'where_ratio': [0.34, 250], 'when_ratio': [0.93, 677]}, 'overall_acc': 42.12, 'what_acc': 36.05, 'who_acc': 52.24, 'how_acc': 85.67, 'where_acc': 42.8, 'when_acc': 78.88}
  bash inf_msvd_qa.sh
  # {'ratios': {'what_ratio': [61.93, 8150], 'who_ratio': [34.6, 4554], 'how_ratio': [2.81, 370], 'where_ratio': [0.21, 28], 'when_ratio': [0.44, 58]}, 'overall_acc': 45.91, 'what_acc': 37.02, 'who_acc': 58.59, 'how_acc': 81.62, 'where_acc': 46.43, 'when_acc': 72.41}
  ```
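
The numbers above can be sanity-checked with the sketch below. The retrieval metrics are computed in the standard way, assuming text *i* is paired with video *i* in a text-to-video similarity matrix; this is an illustration, not necessarily how `run_video_retrieval.py` computes them. For QA, the reported `overall_acc` agrees (within rounding) with the per-type accuracies weighted by the question counts in `ratios`. The helper names are hypothetical.

```python
from statistics import mean, median


def retrieval_metrics(sim):
    """Text-to-video R@1/5/10 (in %), median rank and mean rank.

    sim[i][j] is the similarity of text i to video j; video i is the ground truth for text i.
    The rank counts videos scored strictly higher than the correct one (ties resolved optimistically).
    """
    ranks = [sum(score > row[i] for score in row) for i, row in enumerate(sim)]  # 0-based
    n = len(ranks)
    metrics = {f"r{k}": round(100.0 * sum(r < k for r in ranks) / n, 1) for k in (1, 5, 10)}
    one_based = [r + 1 for r in ranks]
    metrics["medianR"] = float(median(one_based))
    metrics["meanR"] = mean(one_based)
    return metrics


def overall_from_types(result):
    """Count-weighted average of the per-question-type accuracies in a QA result dict."""
    correct, total = 0.0, 0
    for key, (_percent, count) in result["ratios"].items():
        qtype = key[: -len("_ratio")]  # 'what_ratio' -> 'what'
        correct += result[f"{qtype}_acc"] * count
        total += count
    return correct / total


# Outputs of inf_msrvtt_qa.sh and inf_msvd_qa.sh, copied from above.
MSRVTT_QA = {'ratios': {'what_ratio': [68.48, 49872], 'who_ratio': [27.99, 20385], 'how_ratio': [2.25, 1640], 'where_ratio': [0.34, 250], 'when_ratio': [0.93, 677]}, 'overall_acc': 42.12, 'what_acc': 36.05, 'who_acc': 52.24, 'how_acc': 85.67, 'where_acc': 42.8, 'when_acc': 78.88}
MSVD_QA = {'ratios': {'what_ratio': [61.93, 8150], 'who_ratio': [34.6, 4554], 'how_ratio': [2.81, 370], 'where_ratio': [0.21, 28], 'when_ratio': [0.44, 58]}, 'overall_acc': 45.91, 'what_acc': 37.02, 'who_acc': 58.59, 'how_acc': 81.62, 'where_acc': 46.43, 'when_acc': 72.41}
```

For both QA logs, `overall_from_types` lands within 0.01 of the reported `overall_acc`, which is why the dominant "what" questions largely determine the overall score.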


## Downstream Task Finetuning
  - To finetune on downstream tasks with the pre-trained checkpoint `output/pretrain/alpro_pretrained_ckpt.pt`

    ```bash
    cd run_scripts
    bash ft_msrvtt_ret.sh
    bash ft_didemo_ret.sh
    bash ft_msrvtt_qa.sh
    bash ft_msvd_qa.sh
    ```
  
    For example, with MSRVTT retrieval:
    ```bash
    cd ALPRO/

    export PYTHONPATH="$PYTHONPATH:$PWD"
    echo $PYTHONPATH

    CONFIG_PATH='config_release/msrvtt_ret.json'

    # Set -np to the number of GPUs, and --output_dir to a local path for finetuning checkpoints and logs.
    horovodrun -np 8 python src/tasks/run_video_retrieval.py \
        --config $CONFIG_PATH \
        --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S')
    ``` 
 - Run inference with locally-finetuned checkpoints.
   ```bash
    cd ALPRO/

    export PYTHONPATH="$PYTHONPATH:$PWD"
    echo $PYTHONPATH

    STEP='best'

    CONFIG_PATH='config_release/msrvtt_ret.json'
    OUTPUT_DIR='[INPUT_YOUR_OUTPUT_PATH_HERE]'

    TXT_DB='data/msrvtt_ret/txt/test.jsonl'
    IMG_DB='data/msrvtt_ret/videos'

    horovodrun -np 8 python src/tasks/run_video_retrieval.py \
        --do_inference 1 \
        --inference_split test \
        --inference_model_step $STEP \
        --inference_txt_db $TXT_DB \
        --inference_img_db $IMG_DB \
        --inference_batch_size 64 \
        --output_dir $OUTPUT_DIR \
        --config $CONFIG_PATH
   ```  
   - `OUTPUT_DIR` is the path you passed to `--output_dir` in the finetuning script.
   - `$STEP` is a string that tells the script to use the checkpoint `$OUTPUT_DIR/ckpt/model_step_$STEP.pt` for inference.
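
Because the inference command loads `$OUTPUT_DIR/ckpt/model_step_$STEP.pt`, it can help to confirm that file exists, and see which `STEP` values are available, before launching `horovodrun`. The helper names below are hypothetical, and the glob pattern simply assumes every checkpoint follows that naming scheme.

```python
from pathlib import Path


def inference_ckpt(output_dir, step="best"):
    """Checkpoint file used for --inference_model_step $STEP under --output_dir $OUTPUT_DIR."""
    return Path(output_dir) / "ckpt" / f"model_step_{step}.pt"


def available_steps(output_dir):
    """STEP values for which a model_step_<STEP>.pt file exists, sorted as strings."""
    ckpt_dir = Path(output_dir) / "ckpt"
    if not ckpt_dir.is_dir():
        return []
    prefix = "model_step_"
    return sorted(p.stem[len(prefix):] for p in ckpt_dir.glob(prefix + "*.pt"))
```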


## Pretraining
1. Download [WebVid2M](https://github.com/m-bain/frozen-in-time) and [CC-3M](https://github.com/igorbrigadir/DownloadConceptualCaptions).
  
    - Put WebVid2M videos under `data/webvid2m`;
    - 💡 we downsample WebVid2M videos to 10% of the original FPS to speed up video loading;
    - update `data/cc3m/txt/cc3m.json` with your local image paths.

2. Training Prompter:
    ```bash
    cd run_scripts && bash pt_prompter.sh
    ```   

3. Training video-language model: 
    ```bash
    cd run_scripts && bash pt_alpro.sh
    ```
    To use custom prompter weights, change `teacher_weights_path` in `config_release/pretrain_alpro.json`.
4. To finetune with pre-trained checkpoints, please change `e2e_weights_path` in the finetuning config files, e.g. `config_release/msrvtt_ret.json`.
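
Steps 3 and 4 come down to pointing one JSON key at a different checkpoint. Rather than editing the released configs in place, one option is to write a copy with the key overridden and set `CONFIG_PATH` in the launch script to that copy. This is a minimal sketch: the helper name and the output file names are hypothetical, and rejecting unknown keys is a design choice to catch typos such as `e2e_weight_path`, not something the repository does.

```python
import json
from pathlib import Path


def write_config_with(base_config, out_config, **overrides):
    """Copy a JSON config, replacing existing top-level keys; unknown keys raise KeyError."""
    cfg = json.loads(Path(base_config).read_text())
    unknown = sorted(set(overrides) - set(cfg))
    if unknown:
        raise KeyError(f"not in {base_config}: {unknown}")
    cfg.update(overrides)
    Path(out_config).write_text(json.dumps(cfg, indent=2) + "\n")
    return cfg


# Example usage (the checkpoint paths are placeholders):
# write_config_with("config_release/msrvtt_ret.json", "config_release/msrvtt_ret_custom.json",
#                   e2e_weights_path="output/pretrain/my_alpro_ckpt.pt")
# write_config_with("config_release/pretrain_alpro.json", "config_release/pretrain_alpro_custom.json",
#                   teacher_weights_path="output/pretrain/my_prompter.pt")
```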


## Citation

If you find ALPRO useful for your research, please consider citing:
```bibtex
  @inproceedings{li2021align,
    title={Align and Prompt: Video-and-Language Pre-training with Entity Prompts},
    author={Li, Dongxu and Li, Junnan and Li, Hongdong and Niebles, Juan Carlos and Hoi, Steven C.H.},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    year={2022}
  }
```

## Acknowledgement
We thank members at Salesforce Research for their helpful discussions.

The implementation of ALPRO relies on resources from [ClipBERT](https://github.com/jayleicn/ClipBERT),
[transformers](https://github.com/huggingface/transformers), and
[TimeSformer](https://github.com/facebookresearch/TimeSformer/tree/main/timesformer/models).
The code is implemented using [PyTorch](https://github.com/pytorch/pytorch), 
with multi-GPU support from [Horovod](https://github.com/horovod/horovod) and [gradient-checkpoint](https://github.com/csrhddlam/pytorch-checkpoint).  We thank the original authors for their open-sourcing and encourage ALPRO users to cite their works when applicable.



================================================
FILE: SECURITY.md
================================================
## Security

Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as can be, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.


================================================
FILE: config_release/base_model.json
================================================
{
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 512,
    "model_type": "bert",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 0,
    "type_vocab_size": 2,
    "vocab_size": 30522,
    "fusion_layer": 6,
    "encoder_width": 768,
    "itc_token_type": "cls"
}


================================================
FILE: config_release/didemo_ret.json
================================================
{
  "train_datasets": [
    {
      "name": "didemo",
      "txt": "data/didemo_ret/txt/train.jsonl",
      "img": "data/didemo_ret/videos"
    }
  ],
  "val_datasets": [
    {
      "name": "didemo_retrieval",
      "txt": "data/didemo_ret/txt/val.jsonl",
      "img": "data/didemo_ret/videos"
    }
  ],
  "max_txt_len": 50,
  "crop_img_size": 224,
  "resize_size": 256,
  "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 
  "img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
  "img_input_format": "RGB",
  "num_frm": 8,
  "train_n_clips": 1,
  "max_n_example_per_group": 1,
  "model_config": "config_release/base_model.json",
  "tokenizer_dir": "ext/bert-base-uncased/",
  "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
  "e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
  "bert_weights_path": null,
  "train_batch_size": 12,
  "val_batch_size": 12,
  "gradient_accumulation_steps": 1,
  "num_train_epochs": 10,
  "min_valid_steps": 20,
  "num_valid": 20,
  "learning_rate": 4e-5,
  "weight_decay": 1e-3,
  "decay": "linear",
  "optim": "adamw",
  "betas": [0.9, 0.98],
  "dropout": 0.1,
  "grad_norm": 20.0,
  "seed":42,
  "fp16": 0,
  "num_workers": 4
}


================================================
FILE: config_release/msrvtt_qa.json
================================================
{
  "train_datasets": [
    {
      "name": "msrvtt_qa",
      "txt": {
        "msrvtt_qa": "data/msrvtt_qa/txt/train.jsonl"
      },
      "img": "data/msrvtt_qa/videos"
    }
  ],
  "val_datasets": [
    {
      "name": "msrvtt_qa",
      "txt": {
        "msrvtt_qa": "data/msrvtt_qa/txt/val.jsonl"
      },
      "img": "data/msrvtt_qa/videos"
    }
  ],
  "ans2label_path": "data/msrvtt_qa/txt/train_ans2label.json",
  "max_txt_len": 40,
  "crop_img_size": 224,
  "resize_size": 256,
  "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 
  "img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
  "img_input_format": "RGB",
  "train_n_clips": 1,
  "model_config": "config_release/base_model.json",
  "tokenizer_dir": "ext/bert-base-uncased/",
  "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600_gc.json",
  "e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
  "num_frm": 16,
  "train_batch_size": 12,
  "val_batch_size": 12,
  "gradient_accumulation_steps": 2,
  "num_train_epochs": 10,
  "min_valid_steps": 50,
  "num_valid": 50,
  "learning_rate": 5e-5,
  "weight_decay": 1e-3,
  "decay": "linear",
  "optim": "adamw",
  "betas": [0.9, 0.98],
  "dropout": 0.1,
  "grad_norm": 5.0,
  "cnn_lr_decay": "linear",
  "seed":42,
  "fp16": 0,
  "classifier": "mlp",
  "cls_hidden_scale": 2,
  "task": "msrvtt_qa",
  "num_workers": 4
}


================================================
FILE: config_release/msrvtt_ret.json
================================================
{
  "train_datasets": [
    {
      "name": "msrvtt",
      "txt": "data/msrvtt_ret/txt/train.jsonl",
      "img": "data/msrvtt_ret/videos"
    }
  ],
  "val_datasets": [
    {
      "name": "msrvtt_retrieval",
      "txt": "data/msrvtt_ret/txt/val.jsonl",
      "img": "data/msrvtt_ret/videos"
    }
  ],
  "max_txt_len": 40,
  "crop_img_size": 224,
  "resize_size": 256,
  "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 
  "img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
  "img_input_format": "RGB",
  "train_n_clips": 1,
  "model_config": "config_release/base_model.json",
  "tokenizer_dir": "ext/bert-base-uncased/",
  "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
  "e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
  "num_frm": 8,
  "train_batch_size": 8,
  "val_batch_size": 8,
  "gradient_accumulation_steps": 1,
  "num_train_epochs": 5,
  "min_valid_steps": 100,
  "num_valid": 20,
  "learning_rate": 2.5e-5,
  "weight_decay": 1e-3,
  "decay": "linear",
  "optim": "adamw",
  "betas": [0.9, 0.98],
  "dropout": 0.1,
  "grad_norm": 5.0,
  "seed":42,
  "fp16": 0,
  "num_workers": 4
}


================================================
FILE: config_release/msvd_qa.json
================================================
{
  "train_datasets": [
    {
      "name": "msvd_qa",
      "txt": {
        "msvd_qa": "data/msvd_qa/txt/train.jsonl"
      },
      "img": "data/msvd_qa/videos"
    }
  ],
  "val_datasets": [
    {
      "name": "msvd_qa",
      "txt": {
        "msvd_qa": "data/msvd_qa/txt/val.jsonl"
      },
      "img": "data/msvd_qa/videos"
    }
  ],
  "ans2label_path": "data/msvd_qa/txt/train_ans2label.json",
  "num_labels": 2423,
  "max_txt_len": 40,
  "crop_img_size": 224,
  "resize_size": 256,
  "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 
  "img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
  "img_input_format": "RGB",
  "train_n_clips": 1,
  "num_frm": 16,
  "model_config": "config_release/base_model.json",
  "tokenizer_dir": "ext/bert-base-uncased/",
  "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600_gc.json",
  "e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
  "train_batch_size": 12,
  "val_batch_size": 12,
  "gradient_accumulation_steps": 2,
  "num_train_epochs": 15,
  "min_valid_steps": 50,
  "num_valid": 30,
  "learning_rate": 5e-5,
  "weight_decay": 1e-3,
  "decay": "linear",
  "optim": "adamw",
  "betas": [0.9, 0.98],
  "dropout": 0.1,
  "grad_norm": 20.0,
  "cnn_lr_decay": "linear",
  "seed":42,
  "fp16": 0,
  "save_steps_ratio": 0.05,
  "classifier": "mlp",
  "cls_hidden_scale": 2,
  "task": "msvd_qa",
  "num_workers": 4
}


================================================
FILE: config_release/pretrain_alpro.json
================================================
{
  "train_datasets": [
    {
      "name": "webvid2m",
      "ann": "data/webvid2m/txt/train.pkl",
      "txt": null,
      "img": "data/webvid2m/videos"
    },
    {
      "name": "cc3m",
      "ann": "data/cc3m/txt/cc3m.json",
      "txt": null,
      "img": null 
    }
  ],
  "val_datasets": [
    {
      "name": "webvid2m",
      "ann": "data/webvid2m/txt/val.pkl",
      "txt": null,
      "img": "data/webvid2m/videos"
    }
  ],
  "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 
  "img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
  "img_input_format": "RGB",
  "model_type": "pretrain",
  "model_config": "config_release/base_model.json",
  "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
  "visual_weights_path": "vit_base_patch16_224",
  "teacher_weights_path": "output/pretrain/prompter_pretrained.pt",
  "entity_file_path": "data/unigrams.txt",
  "tokenizer_dir": "ext/bert-base-uncased/",
  "max_txt_len": 30,
  "crop_img_size": 224,
  "resize_size": 256,
  "train_batch_size": 16,
  "val_batch_size": 16,
  "gradient_accumulation_steps": 1,
  "num_train_epochs": 10,
  "min_valid_steps": 10,
  "num_valid": 10,
  "learning_rate": 1e-4,
  "decay": "linear",
  "optim": "adamw",
  "betas": [0.9, 0.98],
  "dropout": 0.1,
  "weight_decay": 1e-3,
  "grad_norm": 20.0,
  "seed":42,
  "fp16": 0,
  "use_itm": 1,
  "use_mlm": 1,
  "use_itc": 1,
  "use_mpm": 1,
  "n_workers": 4,
  "save_steps_ratio": 0.01,
  "frm_sampling_strategy": "headtail",
  "num_frm": 4,
  "fps": 0.5,
  "debug": false,
  "warmup_ratio": 0.05,
  "log_interval": 100
}


================================================
FILE: config_release/pretrain_prompter.json
================================================
{
  "train_datasets": [
    {
      "name": "webvid2m",
      "ann": "data/webvid2m/txt/train.pkl",
      "txt": null,
      "img": "data/webvid2m/videos"
    },
    {
      "name": "cc3m",
      "ann": "data/cc3m/txt/cc3m.json",
      "txt": null,
      "img": null 
    }
  ],
  "val_datasets": [
    {
      "name": "webvid2m",
      "ann": "data/webvid2m/txt/val.pkl",
      "txt": null,
      "img": "data/webvid2m/videos"
    }
  ],
  "img_pixel_mean": [0.48145466, 0.4578275, 0.40821073], 
  "img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
  "img_input_format": "RGB",
  "model_type": "pretrain",
  "model_config": "config_release/base_model.json",
  "visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
  "visual_weights_path": "vit_base_patch16_224",
  "tokenizer_dir": "ext/bert-base-uncased/",
  "max_txt_len": 30,
  "crop_img_size": 224,
  "resize_size": 256,
  "train_batch_size": 16,
  "val_batch_size": 16,
  "gradient_accumulation_steps": 2,
  "num_train_epochs": 10,
  "min_valid_steps": 100,
  "num_valid": 10,
  "learning_rate": 1e-4,
  "decay": "linear",
  "optim": "adamw",
  "betas": [0.9, 0.98],
  "dropout": 0.1,
  "weight_decay": 1e-3,
  "grad_norm": 20.0,
  "seed":42,
  "fp16": 0,
  "use_itm": 0,
  "use_mlm": 0,
  "use_itc": 1,
  "n_workers": 4,
  "save_steps_ratio": 0.05,
  "frm_sampling_strategy": "headtail",
  "num_frm": 4,
  "debug": false,
  "warmup_ratio": 0.05,
  "log_interval": 100
}


================================================
FILE: config_release/timesformer_divst_8x32_224_k600.json
================================================
{
    "cls": "TimeSformer",
    "patch_size": 16,
    "attn_drop_rate": 0,
    "drop_rate": 0,
    "drop_path_rate": 0.1,
    "maxpool_kernel_size": 2,
    "use_maxpooling": false,
    "gradient_checkpointing": false
}


================================================
FILE: config_release/timesformer_divst_8x32_224_k600_gc.json
================================================
{
    "cls": "TimeSformer",
    "patch_size": 16,
    "attn_drop_rate": 0,
    "drop_rate": 0,
    "drop_path_rate": 0.1,
    "maxpool_kernel_size": 2,
    "use_maxpooling": false,
    "gradient_checkpointing": true
}


================================================
FILE: env/install_pkg.sh
================================================
apt update
apt install -y lsof

# horovod
HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \
    pip install --no-cache-dir horovod==0.19.4 &&\
    ldconfig

# use the faster pillow-simd instead of the original pillow
# https://github.com/uploadcare/pillow-simd
pip uninstall -y pillow && \
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd

spacy download en

pip install -r requirements.txt

git clone https://github.com/NVIDIA/apex.git &&\
    cd apex &&\
    pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . &&\
    rm -rf ../apex



================================================
FILE: env/requirements.txt
================================================
ipdb
joblib
cytoolz
lz4==2.1.9
lmdb==0.97
msgpack-numpy
msgpack
toolz
transformers==4.11.3
tensorboard
tqdm
easydict
pycocotools>=2.0.1
opencv-python
tensorboardX==2.0
av==8.0.2
ujson
einops
decord
timm


================================================
FILE: run_scripts/clear_cuda_cache.sh
================================================
for i in $(lsof /dev/nvidia* | grep python  | awk '{print $2}' | sort -u); do kill -9 $i; done





================================================
FILE: run_scripts/ft_didemo_ret.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

CONFIG_PATH='config_release/didemo_ret.json'

horovodrun -np 8 python src/tasks/run_video_retrieval.py \
      --config $CONFIG_PATH \
      --output_dir /export/home/workspace/experiments/alpro/finetune/didemo_ret/$(date '+%Y%m%d%H%M%S')


================================================
FILE: run_scripts/ft_msrvtt_qa.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

CONFIG_PATH='config_release/msrvtt_qa.json'

horovodrun -np 8 python src/tasks/run_video_qa.py \
      --config $CONFIG_PATH \
      --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_qa/$(date '+%Y%m%d%H%M%S')


================================================
FILE: run_scripts/ft_msrvtt_ret.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

CONFIG_PATH='config_release/msrvtt_ret.json'

horovodrun -np 8 python src/tasks/run_video_retrieval.py \
      --config $CONFIG_PATH \
      --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S')

================================================
FILE: run_scripts/ft_msvd_qa.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

CONFIG_PATH='config_release/msvd_qa.json'

horovodrun -np 8 python src/tasks/run_video_qa.py \
      --config $CONFIG_PATH \
      --output_dir /export/home/workspace/experiments/alpro/finetune/msvd_qa/$(date '+%Y%m%d%H%M%S')


================================================
FILE: run_scripts/inf_didemo_ret.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

STEP='best'

CONFIG_PATH='config_release/didemo_ret.json'

TXT_DB='data/didemo_ret/txt/test.jsonl'
IMG_DB='data/didemo_ret/videos'

horovodrun -np 8 python src/tasks/run_video_retrieval.py \
      --do_inference 1 \
      --inference_split test \
      --inference_model_step $STEP \
      --inference_txt_db $TXT_DB \
      --inference_img_db $IMG_DB \
      --inference_batch_size 64 \
      --output_dir output/downstreams/didemo_ret/public \
      --config $CONFIG_PATH

================================================
FILE: run_scripts/inf_msrvtt_qa.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

STEP='best'

CONFIG_PATH='config_release/msrvtt_qa.json'

TXT_DB='data/msrvtt_qa/txt/test.jsonl'
IMG_DB='data/msrvtt_qa/videos'

horovodrun -np 8 python src/tasks/run_video_qa.py \
      --do_inference 1 \
      --inference_split test \
      --inference_model_step $STEP \
      --inference_txt_db $TXT_DB \
      --inference_img_db $IMG_DB \
      --inference_batch_size 64 \
      --output_dir output/downstreams/msrvtt_qa/public \
      --config $CONFIG_PATH

================================================
FILE: run_scripts/inf_msrvtt_ret.sh
================================================
cd ..

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

STEP='best'

CONFIG_PATH='config_release/msrvtt_ret.json'

TXT_DB='data/msrvtt_ret/txt/test.jsonl'
IMG_DB='data/msrvtt_ret/videos'

horovodrun -np 8 python src/tasks/run_video_retrieval.py \
      --do_inference 1 \
      --inference_split test \
      --inference_model_step $STEP \
      --inference_txt_db $TXT_DB \
      --inference_img_db $IMG_DB \
      --inference_batch_size 64 \
      --output_dir output/downstreams/msrvtt_ret/public \
      --config $CONFIG_PATH

================================================
FILE: run_scripts/inf_msvd_qa.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

STEP='best'

CONFIG_PATH='config_release/msvd_qa.json'

TXT_DB='data/msvd_qa/txt/test.jsonl'
IMG_DB='data/msvd_qa/videos'

horovodrun -np 8 python src/tasks/run_video_qa.py \
      --do_inference 1 \
      --inference_split test \
      --inference_model_step $STEP \
      --inference_txt_db $TXT_DB \
      --inference_img_db $IMG_DB \
      --inference_batch_size 64 \
      --output_dir output/downstreams/msvd_qa/public \
      --config $CONFIG_PATH

================================================
FILE: run_scripts/pt_alpro.sh
================================================
cd ..

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

CONFIG_PATH='config_release/pretrain_alpro.json'

horovodrun -np 16 python src/pretrain/run_pretrain_sparse.py \
      --config $CONFIG_PATH \
      --output_dir /export/home/workspace/experiments/alpro/vl/$(date '+%Y%m%d%H%M%S')

================================================
FILE: run_scripts/pt_prompter.sh
================================================
cd .. 

export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH

CONFIG_PATH='config_release/pretrain_prompter.json'

horovodrun -np 8 python src/pretrain/run_pretrain_contrastive_only.py \
      --config $CONFIG_PATH \
      --output_dir /export/home/workspace/experiments/alpro/prompter/$(date '+%Y%m%d%H%M%S')

================================================
FILE: src/__init__.py
================================================


================================================
FILE: src/configs/config.py
================================================
"""
Modified from UNITER code
"""
import os
import sys
import json
import argparse

from easydict import EasyDict as edict


def parse_with_config(parsed_args):
    """Set args from the JSON config file passed via --config.
    (1) Config values only overwrite options that were NOT explicitly set on
        the command line, i.e., command-line input always takes precedence.
    (2) Keys in the config file that are not declared in the parser are
        also added to args.
    """
    # convert to EasyDict object, enabling access from attributes even for nested config
    # e.g., args.train_datasets[0].name
    args = edict(vars(parsed_args))
    if args.config is not None:
        with open(args.config) as f:
            config_args = json.load(f)
        override_keys = {arg[2:].split("=")[0] for arg in sys.argv[1:]
                         if arg.startswith("--")}
        for k, v in config_args.items():
            if k not in override_keys:
                setattr(args, k, v)
    del args.config
    return args
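To make the precedence rule of `parse_with_config` concrete, here is a minimal sketch that reproduces it with plain `argparse`. The helper `merge_with_config` is hypothetical: it takes `argv` explicitly instead of reading `sys.argv`, and returns a plain dict instead of an `EasyDict`.

```python
import argparse
import json


def merge_with_config(parser, argv, config):
    """Apply config values, except for options given explicitly in argv (sketch)."""
    args = vars(parser.parse_args(argv))
    # "--num_frm 8" and "--num_frm=8" both yield the key "num_frm"
    override_keys = {a[2:].split("=")[0] for a in argv if a.startswith("--")}
    for key, value in config.items():
        if key not in override_keys:
            args[key] = value  # also adds keys the parser never declared
    return args


parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=5e-5)
parser.add_argument("--num_frm", type=int, default=3)
config = json.loads('{"learning_rate": 1e-4, "num_frm": 4, "task": "msrvtt_qa"}')

args = merge_with_config(parser, ["--num_frm", "8"], config)
print(args)  # {'learning_rate': 0.0001, 'num_frm': 8, 'task': 'msrvtt_qa'}
```

Two consequences of this design are worth keeping in mind. First, because config values are assigned directly, they bypass argparse's `type` and `choices` checks. Second, the override set is built from the literal option names on the command line, while argparse accepts unambiguous prefixes (e.g. `--learning 3e-5`); in that case the override set contains `learning`, so the config value silently wins. Spelling options out in full avoids this.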


class SharedConfigs(object):
    """Shared options for pre-training and downstream tasks.
    For each downstream task, implement a get_*_args function,
    see `get_pretraining_args()`

    Usage:
    >>> shared_configs = SharedConfigs()
    >>> pretraining_config = shared_configs.get_pretraining_args()
    """

    def __init__(self, desc="shared config for pretraining and finetuning"):
        parser = argparse.ArgumentParser(description=desc)
        # debug parameters
        parser.add_argument(
            "--debug", type=int, choices=[0, 1], default=0,
            help="debug mode, output extra info & break all loops. "
                 "0: disable, 1 enable")
        parser.add_argument(
            "--data_ratio", type=float, default=1.0,
            help="portion of train/val examples to use, "
                 "e.g., to overfit a small set of data")

        # Required parameters
        parser.add_argument(
            "--model_config", type=str,
            help="path to model structure config json")
        parser.add_argument(
            "--tokenizer_dir", type=str, help="path to tokenizer dir")
        parser.add_argument(
            "--output_dir", type=str,
            help="dir to store model checkpoints & training meta.")

        # data preprocessing parameters
        parser.add_argument(
            "--max_txt_len", type=int, default=20, help="max text #tokens ")
        # parser.add_argument(
        #     "--max_img_size", type=int, default=448,
        #     help="max image longer side size, shorter side will be padded with zeros")
        parser.add_argument(
            "--img_pixel_mean", type=float, default=None,
            nargs=3, help="image pixel mean")
        parser.add_argument(
            "--img_pixel_std", type=float, default=None,
            nargs=3, help="image pixel std")
        parser.add_argument(
            "--img_input_format", type=str, default="BGR",
            choices=["BGR", "RGB"], help="image input format is BGR for detectron2")
        parser.add_argument(
            "--max_n_example_per_group", type=int, default=1,
            help="max #examples (e.g., captions) paired with each image/video in an input group. "
                 "1: each image is paired with a single sent., equivalent to sample by sent.; "
                 "X (X>1): each image can be paired with a maximum of X sent.; X>1 can be used "
                 "to reduce image processing time, including basic transform (resize, etc) and CNN encoding"
        )
        # video specific parameters
        parser.add_argument("--fps", type=int, default=1, help="video frame rate to use")
        parser.add_argument("--num_frm", type=int, default=3,
                            help="#frames to use per clip -- we first sample a clip from a video, "
                                 "then uniformly sample num_frm from the clip. The length of the clip "
                                 "will be fps * num_frm")
        parser.add_argument("--frm_sampling_strategy", type=str, default="rand",
                            choices=["rand", "uniform", "start", "middle", "end"],
                            help="see src.datasets.dataset_base.extract_frames_from_video_binary for details")

        # MIL (multiple instance learning) training settings
        parser.add_argument("--train_n_clips", type=int, default=3,
                            help="#clips to sample from each video for MIL training")
        parser.add_argument("--score_agg_func", type=str, default="mean",
                            choices=["mean", "max", "lse"],
                            help="score (from multiple clips) aggregation function, lse = LogSumExp")
        parser.add_argument("--random_sample_clips", type=int, default=1, choices=[0, 1],
                            help="randomly sample clips for training, otherwise use uniformly sampled clips.")

        # training parameters
        parser.add_argument(
            "--train_batch_size", default=128, type=int,
            help="Single-GPU batch size for training for Horovod.")
        parser.add_argument(
            "--val_batch_size", default=128, type=int,
            help="Single-GPU batch size for validation for Horovod.")
        parser.add_argument(
            "--gradient_accumulation_steps", type=int, default=1,
            help="#updates steps to accumulate before performing a backward/update pass."
                 "Used to simulate larger batch size training. The simulated batch size "
                 "is train_batch_size * gradient_accumulation_steps for a single GPU.")
        parser.add_argument("--learning_rate", default=5e-5, type=float,
                            help="initial learning rate.")
        parser.add_argument(
            "--log_interval", default=500, type=int,
            help="log to tensorboard every log_interval steps.")
        parser.add_argument(
            "--num_valid", default=20, type=int,
            help="Run validation X times during training and checkpoint.")
        parser.add_argument(
            "--min_valid_steps", default=100, type=int,
            help="minimum #steps between two validation runs")
        parser.add_argument(
            "--save_steps_ratio", default=0.01, type=float,
            help="save every 0.01*global steps to resume after preemption, "
                 "not used for checkpointing.")
        parser.add_argument("--num_train_epochs", default=10, type=int,
                            help="Total #training epochs.")
        parser.add_argument("--optim", default="adamw",
                            choices=["adam", "adamax", "adamw"],
                            help="optimizer")
        parser.add_argument("--betas", default=[0.9, 0.98], type=float,
                            nargs=2, help="beta for adam optimizer")
        parser.add_argument("--decay", default="linear",
                            choices=["linear", "invsqrt"],
                            help="learning rate decay method")
        parser.add_argument("--dropout", default=0.1, type=float,
                            help="tune dropout regularization")
        parser.add_argument("--weight_decay", default=1e-3, type=float,
                            help="weight decay (L2) regularization")
        parser.add_argument("--grad_norm", default=2.0, type=float,
                            help="gradient clipping (-1 for no clipping)")
        parser.add_argument(
            "--warmup_ratio", default=0.1, type=float,
            help="ratio of total training steps to perform linear learning rate warmup for (invsqrt decay).")
        parser.add_argument("--transformer_lr_mul", default=1.0, type=float,
                            help="lr_mul for transformer")
        parser.add_argument("--step_decay_epochs", type=int,
                            nargs="+", help="multi_step decay epochs")
        # model arch
        parser.add_argument(
            "--model_type", type=str, default="pretrain",
            help="type of e2e model to use. Support only 'pretrain' for now. ")
        parser.add_argument(
            "--timesformer_model_cfg", type=str, default="",
            help="path to TimeSformer model config (JSON)")

        # checkpoint
        parser.add_argument("--e2e_weights_path", type=str,
                            help="path to e2e model weights")
        parser.add_argument(
            "--clip_init", default=0, type=int, choices=[0, 1],
            help="1 for using clip ckpt for init.")
        parser.add_argument("--bert_weights_path", type=str,
                            help="path to BERT weights, only use for pretraining")

        # inference only; please include the substring `inference' in the
        # option name to avoid it being overwritten by loaded options,
        # see start_inference() in run_vqa_w_hvd.py
        parser.add_argument("--inference_model_step", default=-1, type=str,
                            help="pretrained model checkpoint step")
        parser.add_argument(
            "--do_inference", default=0, type=int, choices=[0, 1],
            help="perform inference run. 0: disable, 1 enable")
        parser.add_argument(
            "--inference_split", default="val",
            help="For val, the data should have ground-truth associated with it. "
                 "For test*, the data comes with no ground-truth.")
        parser.add_argument("--inference_txt_db", type=str,
                            help="path to txt_db file for inference")
        parser.add_argument("--inference_img_db", type=str,
                            help="path to img_db file for inference")
        parser.add_argument("--inference_batch_size", type=int, default=64,
                            help="single-GPU batch size for inference")
        parser.add_argument("--inference_n_clips", type=int, default=1,
                            help="uniformly sample `inference_n_clips` clips, "
                                 "each containing `num_frm` frames. When it == 1, "
                                 "use the frm_sampling_strategy to sample num_frm frames. "
                                 "When it > 1, ignore frm_sampling_strategy and "
                                 "uniformly sample N clips, each with num_frm frames.")

        # device parameters
        parser.add_argument("--seed", type=int, default=42,
                            help="random seed for initialization")
        parser.add_argument(
            "--fp16", type=int, choices=[0, 1], default=0,
            help="Use 16-bit float precision instead of 32-bit. "
                 "0: disable, 1 enable")
        parser.add_argument("--n_workers", type=int, default=4,
                            help="#workers for data loading")
        parser.add_argument("--pin_mem", type=int, choices=[0, 1], default=1,
                            help="pin memory. 0: disable, 1 enable")

        # can use config files, will only overwrite unset parameters
        parser.add_argument("--config", help="JSON config files")
        self.parser = parser

    def parse_args(self):
        parsed_args = self.parser.parse_args()
        args = parse_with_config(parsed_args)

        # convert all [0, 1] options to bool, including task-specific ones
        zero_one_options = [
            "fp16", "pin_mem", "use_itm", "use_mlm", "use_itc", "debug", #"freeze_cnn",
            "do_inference",
        ]
        for option in zero_one_options:
            if hasattr(args, option):
                setattr(args, option, bool(getattr(args, option)))

        # basic checks
        # This is handled at TrainingRestorer
        # if exists(args.output_dir) and os.listdir(args.output_dir):
        #     raise ValueError(f"Output directory ({args.output_dir}) "
        #                      f"already exists and is not empty.")
        if args.step_decay_epochs and args.decay != "multi_step":
            # Warning(...) alone only builds an exception object; warn explicitly
            import warnings
            warnings.warn(
                f"--step_decay_epochs set to {args.step_decay_epochs} "
                f"but will not be effective, as --decay is set to {args.decay}")

        assert args.gradient_accumulation_steps >= 1, \
            f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps} "

        assert 1 >= args.data_ratio > 0, \
            f"--data_ratio should be in (0, 1.0], but got {args.data_ratio}"

        return args

    def get_sparse_pretraining_args(self):
        # pre-training args
        self.parser.add_argument(
            "--use_itm", type=int, choices=[0, 1], default=0,
            help="enable itm loss. 0: disable, 1 enable")
        self.parser.add_argument(
            "--use_mlm", type=int, choices=[0, 1], default=0,
            help="enable mlm loss. 0: disable, 1 enable")
        self.parser.add_argument(
            "--use_itc", type=int, choices=[0, 1], default=0,
            help="enable itc loss. 0: disable, 1 enable")
        
        # sparse pretraining-specific settings
        self.parser.add_argument(
            "--crop_img_size", type=int, default=256,
            help="crop size during pre-training.")
        self.parser.add_argument(
            "--resize_size", type=int, default=288,
            help="resize frames to square, ignoring aspect ratio.")

        # MPM-specific
        self.parser.add_argument(
            "--use_mpm", type=int, choices=[0, 1], default=0,
            help="enable mpm loss. 0: disable, 1 enable")
        self.parser.add_argument("--teacher_weights_path", type=str,
                                 help="path to teacher model weights, only used for pretraining.")
        self.parser.add_argument("--entity_file_path", type=str,
                                 help="path to selected NOUN entities.")
        self.parser.add_argument(
            "--num_entities", type=int, default=1000,
            help="maximum entities to consider for pseudo labels.")

        args = self.parse_args()
        return args

    def get_video_retrieval_args(self):
        self.parser.add_argument("--eval_retrieval_batch_size", type=int, default=256,
                                 help="batch size for retrieval, since each batch will only have one image, "
                                      "retrieval allows larger batch size")

        args = self.parse_args()
        return args

    def get_nlvl_args(self):
        args = self.parse_args()

        return args


    def get_vqa_args(self):
        self.parser.add_argument("--ans2label_path", type=str,
                                 help="path to {answer: label} file")
        self.parser.add_argument("--loss_type", type=str, default="bce",
                                 help="loss type")
        self.parser.add_argument("--classifier", type=str, default="mlp",
                                 choices=["mlp", "linear"],
                                 help="classifier type")
        self.parser.add_argument(
            "--cls_hidden_scale", type=int, default=2,
            help="scaler of the intermediate linear layer dimension for mlp classifier")
        self.parser.add_argument("--num_labels", type=int, default=3129,
                                 help="#labels/output-dim for classifier")
        return self.parse_args()

    def get_video_qa_args(self):
        self.parser.add_argument(
            "--task", type=str,
            choices=["action", "transition", "frameqa", "msrvtt_qa", "msvd_qa"],
            help="TGIF-QA tasks, MSRVTT-QA and MSVD-QA (only msrvtt_qa and msvd_qa are implemented)")
        self.parser.add_argument("--loss_type", type=str, default="ce",
                                 help="loss type, will be overwritten later")
        self.parser.add_argument("--classifier", type=str, default="mlp",
                                 choices=["mlp", "linear"],
                                 help="classifier type")
        self.parser.add_argument(
            "--cls_hidden_scale", type=int, default=2,
            help="scaler of the intermediate linear layer dimension for mlp classifier")
        # for frameQA msrvtt_qa
        self.parser.add_argument("--ans2label_path", type=str, default=None,
                                 help="path to {answer: label} file")

        # manually setup config by task type
        args = self.parse_args()
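        if args.max_n_example_per_group != 1:
            # Warning(...) alone only builds an exception object; warn explicitly
            import warnings
            warnings.warn("For TGIF-QA, most GIFs are only paired with a single example, so there is no need to "
                          f"use max_n_example_per_group={args.max_n_example_per_group} "
                          "larger than 1. Automatically reset to 1.")
            args.max_n_example_per_group = 1
        # ans2label_path defaults to None, and os.path.exists(None) raises TypeError
        if args.ans2label_path and os.path.exists(args.ans2label_path):
            with open(args.ans2label_path, "r") as f:
                num_answers = len(json.load(f))
        else:
            num_answers = 0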

        if args.task in ["msrvtt_qa", "msvd_qa"]:
            args.num_labels = max(num_answers, 1500)
            args.loss_type = "ce"
        else:
            raise NotImplementedError
        return args


shared_configs = SharedConfigs()
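The `--train_n_clips` and `--score_agg_func` options above control how scores from several clips of the same video are combined during MIL training. A minimal pure-Python sketch of the three choices (`mean`, `max`, `lse`) follows; the function name `aggregate_clip_scores` is hypothetical, and the repository's own tensor-based implementation is not part of this section and may differ in details.

```python
import math
from typing import Sequence


def aggregate_clip_scores(scores: Sequence[float], func: str = "mean") -> float:
    """Combine per-clip scores for one video-text pair (hypothetical helper)."""
    if not scores:
        raise ValueError("need at least one clip score")
    if func == "mean":
        return sum(scores) / len(scores)
    if func == "max":
        return max(scores)
    if func == "lse":
        # numerically stable LogSumExp: subtract the max before exponentiating
        m = max(scores)
        return m + math.log(sum(math.exp(s - m) for s in scores))
    raise ValueError(f"unknown score_agg_func: {func}")


clip_scores = [1.0, 2.0, 3.0]  # e.g. train_n_clips=3
print(aggregate_clip_scores(clip_scores, "mean"))  # 2.0
print(aggregate_clip_scores(clip_scores, "max"))   # 3.0
print(round(aggregate_clip_scores(clip_scores, "lse"), 4))  # 3.4076
```

LogSumExp behaves like a smooth max: its value is always at least the largest score and at most the largest score plus log(n) for n clips, so every clip still receives some gradient.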

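The `--train_batch_size` and `--gradient_accumulation_steps` help texts describe a per-GPU quantity. Assuming standard Horovod data parallelism (each of the `-np` processes draws its own batch and gradients are averaged before each update), the number of examples per optimizer step works out as below. The helper name is hypothetical, and the actual values depend on each JSON config, which may override the parser defaults.

```python
def effective_batch_size(train_batch_size: int,
                         gradient_accumulation_steps: int,
                         n_gpus: int) -> int:
    """Examples contributing to one optimizer update (hypothetical helper)."""
    per_gpu = train_batch_size * gradient_accumulation_steps  # simulated single-GPU batch
    return per_gpu * n_gpus  # Horovod averages gradients across all processes


# parser defaults with `horovodrun -np 8`, as in the fine-tuning scripts
print(effective_batch_size(128, 1, 8))  # 1024
# same update size with a 4x smaller per-step batch on each GPU
print(effective_batch_size(32, 4, 8))   # 1024
```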

================================================
FILE: src/datasets/data_utils.py
================================================
import torch
import random
import torchvision.transforms as transforms
from torchvision.transforms.functional import pad as img_pad
from torchvision.transforms.functional import resize as img_resize
from torch.nn.functional import interpolate as img_tensor_resize
from torch.nn.functional import pad as img_tensor_pad
from torch.nn.modules.utils import _quadruple
from src.utils.basic_utils import flat_list_of_lists
import numbers
import numpy as np
from PIL import Image
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}


def mask_batch_text_tokens(
        inputs, tokenizer, mlm_probability=0.15, is_train=True):
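    """ modified from transformers.data.data_collator
    Args:
        inputs: (B, L), 2D torch.Tensor, does not work for 1D. It has already been padded.
            Note: it is modified in place (masked positions are overwritten).
        tokenizer: a HuggingFace tokenizer with a mask token (e.g., BertTokenizer)
        mlm_probability: float, probability of selecting each non-special, non-pad token
        is_train: intended to switch between random masking (True) and fixed-position
            masking for deterministic evaluation; currently unused, masking is always random.
    Returns:
        inputs: (B, L) masked token ids
        labels: (B, L) original ids at selected positions, -100 elsewhere
    """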
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. "
            "Remove the --use_mlm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training
    # (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(
            val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(
        special_tokens_mask, dtype=torch.bool), value=0.0)
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(
        torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(
        tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    indices_random = torch.bernoulli(
        torch.full(labels.shape, 0.5)
        ).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(
        len(tokenizer), labels.shape,
        dtype=torch.long)  # len(tokenizer) == #vocab
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
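The 80/10/10 split above is easiest to see on a single sequence. The sketch below is a pure-Python re-implementation for illustration only: it uses no tokenizer, and the special-token ids, the `[MASK]` id 103 and the vocabulary size 30522 are the bert-base-uncased values, assumed here. It draws one uniform number per selected token instead of two Bernoulli draws; since 0.2 * 0.5 = 0.1, the resulting distribution is the same.

```python
import random
from typing import List, Optional, Sequence, Set, Tuple


def mask_tokens(ids: Sequence[int], special_ids: Set[int], mask_id: int,
                vocab_size: int, mlm_probability: float = 0.15,
                rng: Optional[random.Random] = None) -> Tuple[List[int], List[int]]:
    """BERT-style masking for one sequence; returns (masked_ids, labels)."""
    rng = rng or random.Random()
    masked = list(ids)
    labels = []
    for i, tok in enumerate(ids):
        if tok in special_ids or rng.random() >= mlm_probability:
            labels.append(-100)  # not selected: ignored by the loss
            continue
        labels.append(tok)  # selected: the model must recover the original id
        r = rng.random()
        if r < 0.8:
            masked[i] = mask_id  # 80%: replace with [MASK]
        elif r < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return masked, labels


CLS, SEP, PAD, MASK = 101, 102, 0, 103  # bert-base-uncased ids
ids = [CLS] + list(range(1000, 21000)) + [SEP]
masked, labels = mask_tokens(ids, {CLS, SEP, PAD}, MASK, 30522,
                             mlm_probability=0.15, rng=random.Random(0))
selected = [i for i, lab in enumerate(labels) if lab != -100]
n_mask = sum(masked[i] == MASK for i in selected)
print(len(selected) / 20000, n_mask / len(selected))  # roughly 0.15 and 0.8
```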


def select_batch_text_pivots(
        inputs, tokenizer, ent2id, mpm_probability=1.0, is_train=True):
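    """ Given an input text sequence, generate masks and prototype labels such that:
    1) special tokens ([CLS], [SEP], [MASK], [PAD]) are never selected;
    2) all WordPiece sub-tokens of a word are always selected together.

    Args:
        inputs: (B, L) 2D torch.Tensor of padded token ids.
        tokenizer: a HuggingFace WordPiece tokenizer (sub-words prefixed with '##').
        ent2id: dict mapping an entity word to its prototype label id.
        mpm_probability: float, probability of selecting each non-special token as a pivot.
        is_train: currently unused.
    Returns:
        pivot_indices: (B, L) bool tensor, True at selected pivot positions.
        labels: (B, L) prototype label ids at pivot positions, -100 elsewhere.
    """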
    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. "
            "Remove the --use_mpm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    # We sample a few tokens in each sequence as pivots
    probability_matrix = torch.full(labels.shape, mpm_probability)
    # ignore [CLS] [SEP] [MASK] tokens
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(
            val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    # ignore [PAD] tokens
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    
    # create masking indices
    pivot_indices = torch.bernoulli(probability_matrix).bool()
    labels[special_tokens_mask] = -100  # We only compute loss on masked tokens
    labels[~pivot_indices] = -100  # We only compute loss on masked tokens

    # selected pivot positions: (1) non-special token; (2) selected based on mpm probability. 
    text_pivots_pos = (labels > 0).nonzero()
    
    for tpp in text_pivots_pos:

        orig_tpp = tpp.clone()
        
        bth = tpp[0]
        orig_text_pos = orig_tpp[1]

        text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]
        next_text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]+1]])[0] if tpp[1]+1 < inputs.shape[1] else None

        # TODO: consider supporting encodings beyond WordPiece.
        if text_token.startswith('##'):
            # if it is a continuation sub-token ('##...'), walk backwards until we find the word's first sub-token
            orig_text_token = ''

            while True:
                if not text_token.startswith('##'):
                    orig_text_token = text_token + orig_text_token

                    break
                else:
                    orig_text_token = text_token[2:] + orig_text_token

                    tpp[1] -= 1

                text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]
            
            try:
                # assign prototype labels to all the WordPiece sub-tokens in the pivot word
                labels[bth][tpp[1]: orig_text_pos + 1] = ent2id[orig_text_token]
                pivot_indices[bth][tpp[1]: orig_text_pos + 1] = True
            except KeyError:
                # we do not have this word for prototype
                labels[bth][orig_text_pos] = -100

        elif next_text_token is not None and next_text_token.startswith('##'):
            # if it is the first sub-token of a multi-piece word, walk forward
            # until we reach the end of the word
            full_text_token = text_token

            while True:
                tpp[1] += 1
                if tpp[1] >= inputs.shape[1]:
                    # reached the end of the sequence
                    break
                text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]

                if not text_token.startswith('##'):
                    # find the next prefix/word
                    break
                else:
                    # find continuing sub-tokens
                    full_text_token = full_text_token + text_token[2:]

            try:
                # assign prototype labels to all the WordPiece sub-tokens in the pivot word
                labels[bth][orig_text_pos: tpp[1]] = ent2id[full_text_token]
                pivot_indices[bth][orig_text_pos: tpp[1]] = True
            except KeyError:
                # we do not have this word for prototype
                labels[bth][orig_text_pos] = -100

        else:
            # the word is treated in whole by BERT tokenizer
            try:
                labels[bth][tpp[1]] = ent2id[text_token]
            except KeyError:
                # we do not have this word for prototype
                labels[bth][tpp[1]] = -100

    # restore mask if the word is not in the entity list
    pivot_indices[labels==-100] = False

    return pivot_indices, labels
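`select_batch_text_pivots` relies on the WordPiece convention that continuation sub-tokens start with `##`. The pure-Python sketch below shows the same grouping on token strings, with every position selected (as with `mpm_probability=1.0`). The `ent2id` mapping and the helper names are hypothetical, and unlike the original it works on one sequence of token strings rather than a batch of ids.

```python
from typing import Dict, List, Sequence, Tuple

SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "[MASK]", "[UNK]"}


def wordpiece_spans(tokens: Sequence[str]) -> List[Tuple[str, int, int]]:
    """Group WordPiece tokens into (word, start, end) spans, end exclusive."""
    spans: List[Tuple[str, int, int]] = []
    for i, tok in enumerate(tokens):
        if tok.startswith("##") and spans:
            word, start, _ = spans[-1]
            spans[-1] = (word + tok[2:], start, i + 1)  # extend the current word
        else:
            spans.append((tok, i, i + 1))  # a new word starts here
    return spans


def pivot_labels(tokens: Sequence[str], ent2id: Dict[str, int]) -> List[int]:
    """Prototype label for every sub-token of a known entity word, -100 elsewhere."""
    labels = [-100] * len(tokens)
    for word, start, end in wordpiece_spans(tokens):
        if word in SPECIAL_TOKENS or word not in ent2id:
            continue  # special token or no prototype for this word
        for j in range(start, end):
            labels[j] = ent2id[word]
    return labels


tokens = ["[CLS]", "a", "sky", "##dive", "##r", "jumps", "[SEP]"]
ent2id = {"skydiver": 7, "jumps": 3}  # hypothetical entity vocabulary
print(pivot_labels(tokens, ent2id))  # [-100, -100, 7, 7, 7, 3, -100]
```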


def image_to_tensor(image: np.ndarray, keepdim: bool = True) -> torch.Tensor:
    """Converts a numpy image to a PyTorch 4d tensor image.
    Args:
        image (numpy.ndarray): image of the form :math:`(H, W, C)`, :math:`(H, W)` or
            :math:`(B, H, W, C)`.
        keepdim (bool): If ``False`` unsqueeze the input image to match the shape
            :math:`(B, H, W, C)`. Default: ``True``
    Returns:
        torch.Tensor: tensor of the form :math:`(B, C, H, W)` if keepdim is ``False``,
            :math:`(C, H, W)` otherwise.
    """
    if not isinstance(image, (np.ndarray,)):
        raise TypeError("Input type must be a numpy.ndarray. Got {}".format(
            type(image)))

    if len(image.shape) > 4 or len(image.shape) < 2:
        raise ValueError(
            "Input size must be a two, three or four dimensional array")

    input_shape = image.shape
    tensor: torch.Tensor = torch.from_numpy(image)

    if len(input_shape) == 2:
        # (H, W) -> (1, H, W)
        tensor = tensor.unsqueeze(0)
    elif len(input_shape) == 3:
        # (H, W, C) -> (C, H, W)
        tensor = tensor.permute(2, 0, 1)
    elif len(input_shape) == 4:
        # (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.permute(0, 3, 1, 2)
        keepdim = True  # no need to unsqueeze
    else:
        raise ValueError(
            "Cannot process image with shape {}".format(input_shape))

    return tensor.unsqueeze(0) if not keepdim else tensor


def get_padding(image, max_w, max_h, pad_all=False):
    # keep the images to upper-left corner
    if isinstance(image, torch.Tensor):
        h, w = image.shape[-2:]
    else:
        w, h = image.size
    h_padding, v_padding = max_w - w, max_h - h
    if pad_all:
        h_padding /= 2
        v_padding /= 2
        l_pad = h_padding if h_padding % 1 == 0 else h_padding+0.5
        t_pad = v_padding if v_padding % 1 == 0 else v_padding+0.5
        r_pad = h_padding if h_padding % 1 == 0 else h_padding-0.5
        b_pad = v_padding if v_padding % 1 == 0 else v_padding-0.5
    else:
        l_pad, t_pad = 0, 0
        r_pad, b_pad = h_padding, v_padding
    if isinstance(image, torch.Tensor):
        padding = (int(l_pad), int(r_pad), int(t_pad), int(b_pad))
    else:
        padding = (int(l_pad), int(t_pad), int(r_pad), int(b_pad))
    return padding
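As a worked example of `get_padding`: padding a 221×224 image to 224×224 with `pad_all=True` leaves 3 columns to fill, split as 2 on the left and 1 on the right. The pure-Python sketch below reproduces that arithmetic with integer division and shows the reordering between PIL's `(left, top, right, bottom)` and the `(left, right, top, bottom)` order that `torch.nn.functional.pad` uses for the last two dimensions. The helper names are hypothetical.

```python
from typing import Tuple


def padding_ltrb(w: int, h: int, max_w: int, max_h: int,
                 pad_all: bool = False) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) padding to reach max_w x max_h."""
    dw, dh = max_w - w, max_h - h
    if not pad_all:
        return 0, 0, dw, dh  # keep the image in the upper-left corner
    left, top = (dw + 1) // 2, (dh + 1) // 2  # odd remainder goes left/top
    return left, top, dw - left, dh - top


def ltrb_to_torch(pad: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
    """Reorder to (left, right, top, bottom) for F.pad on the last two dims."""
    left, top, right, bottom = pad
    return left, right, top, bottom


pil_pad = padding_ltrb(221, 224, 224, 224, pad_all=True)
print(pil_pad, ltrb_to_torch(pil_pad))  # (2, 0, 1, 0) (2, 1, 0, 0)
print(padding_ltrb(200, 100, 224, 224))  # (0, 0, 24, 124)
```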


class ImagePad(object):
    def __init__(self, max_w, max_h, fill=0, padding_mode='constant'):
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        self.max_w = max_w
        self.max_h = max_h
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
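        """
        Args:
            img (PIL Image or torch.Tensor): Image to be padded to
                (max_h, max_w); the original content stays in the upper-left corner.

        Returns:
            PIL Image or torch.Tensor: Padded image, same type as the input.
        """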
        if isinstance(img, torch.Tensor):
            paddings = _quadruple(get_padding(img, self.max_w, self.max_h))
            return img_tensor_pad(
                img, paddings,
                self.padding_mode, self.fill)
        return img_pad(
            img, get_padding(img, self.max_w, self.max_h),
            self.fill, self.padding_mode)

    def __repr__(self):
        return self.__class__.__name__ + '(max_w={0}, max_h={1}, fill={2}, padding_mode={3})'.\
            format(self.max_w, self.max_h, self.fill, self.padding_mode)


def get_resize_size(image, max_size):
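    """
    Args:
        image: PIL Image or torch.tensor
        max_size: int, target length of the longer side.

    Returns:
        tuple(int, int): (new_height, new_width), with the longer side equal to
            max_size and the aspect ratio preserved (up to int() truncation).

    Note the height/width order difference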
    >>> pil_img = Image.open("raw_img_tensor.jpg")
    >>> pil_img.size
    (640, 480)  # (width, height)
    >>> np_img = np.array(pil_img)
    >>> np_img.shape
    (480, 640, 3)  # (height, width, 3)
    """
    # note the order of height and width for different inputs
    if isinstance(image, torch.Tensor):
        height, width = image.shape[-2:]
    else:
        width, height = image.size

    if height >= width:
        ratio = width*1./height
        new_height = max_size
        new_width = new_height * ratio
    else:
        ratio = height*1./width
        new_width = max_size
        new_height = new_width * ratio
    size = (int(new_height), int(new_width))
    return size
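
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# the arithmetic of get_resize_size on plain (height, width) integers. The
# longer side becomes max_size, and the shorter side is scaled by the same
# ratio and then truncated toward zero by int().
def _sketch_longer_side_resize(height, width, max_size):
    if height >= width:
        new_height = max_size
        new_width = new_height * (width * 1. / height)
    else:
        new_width = max_size
        new_height = new_width * (height * 1. / width)
    return int(new_height), int(new_width)

# e.g. a 480x640 (h x w) frame with max_size=1000 becomes 750x1000, while a
# 640x480 frame with max_size=224 becomes 224x168.
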

class VideoRandomSquareCrop(object):
    def __init__(self, crop_size, p=0.5):
        assert isinstance(crop_size, int)
        self.crop_size = crop_size
        self.p = p  # flip probability; unused while the random flip below is disabled

    def __call__(self, video):
        """
        Args:
            video (torch.tensor): video of shape (T, C, H, W) to be cropped.

        Returns:
            torch.tensor: cropped video of shape (T, C, crop_size, crop_size).
        """
        if isinstance(video, torch.Tensor):
            if len(video.shape) == 4:
                t, c, h, w = video.shape
            else:
                raise RuntimeError('Expecting 4-dimensional tensor of shape (t,c,h,w), got {}'.format(video.shape))

            # if random.uniform(0, 1) < self.p:
            #     video = torch.flip(video, (3,))

            # sample the top-left corner; the same window is applied to every frame
            x = random.randint(0, h - self.crop_size)
            y = random.randint(0, w - self.crop_size)

            return video[:, :, x: x + self.crop_size, y: y + self.crop_size]

        else:
            raise NotImplementedError('Support only torch.Tensor as input, got {}'.format(type(video)))
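
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# how VideoRandomSquareCrop picks its window. random.randint is inclusive on
# both ends, so every top-left corner that keeps the crop inside the frame can
# be drawn. One window is shared by all T frames, which keeps the crop
# temporally consistent. The explicit size check is an addition of this sketch;
# the class above would instead fail inside random.randint.
import random


def _sketch_square_crop_window(h, w, crop_size, rng=random):
    """Return (top, left, bottom, right) of a random crop_size x crop_size window."""
    if h < crop_size or w < crop_size:
        raise ValueError(f"frame {h}x{w} is smaller than crop {crop_size}")
    top = rng.randint(0, h - crop_size)
    left = rng.randint(0, w - crop_size)
    return top, left, top + crop_size, left + crop_size

# On a (T, C, H, W) tensor the window is applied as
# video[:, :, top:bottom, left:right], as in VideoRandomSquareCrop.__call__.
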


class VideoResizeSquare(object):
    def __init__(self, out_size, interpolation='nearest'):
        assert isinstance(out_size, int)
        self.out_size = out_size
        self.interpolation = interpolation

    def __call__(self, video):
        """
        Args:
            video (torch.tensor): video of shape (T, C, H, W) to be resized.

        Returns:
            torch.tensor: video resized to (T, C, out_size, out_size); the
                aspect ratio is not preserved.
        """
        if isinstance(video, torch.Tensor):
            if len(video.shape) == 4:
                t, c, h, w = video.shape
                assert c == 3, 'Expecting 3-channel color video, got video of shape {}'.format(video.shape)
            else:
                raise RuntimeError('Expecting 4-dimensional tensor of shape (t,c,h,w), got {}'.format(video.shape))

            resized_video = img_tensor_resize(video, size=(self.out_size, self.out_size), mode=self.interpolation)

            return resized_video

        else:
            # in other data classes, the order of the shape dimensions might be different.
            raise NotImplementedError('Support only torch.Tensor as input, got {}'.format(type(video)))

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(
            self.out_size, self.interpolation)


class ImageResize(object):
    """Resize the input image (PIL Image or torch.tensor), preserving its aspect ratio.

    Args:
        max_size (int): Desired length of the longer edge. The shorter edge
            is scaled proportionally, i.e., if height > width, the image will
            be rescaled to (max_size, max_size * width / height).
        interpolation (int or str, optional): Desired interpolation. A PIL
            constant for PIL images (default ``PIL.Image.BILINEAR``) and a
            mode string such as "bilinear" for tensors.
    """

    def __init__(self, max_size, interpolation=Image.BILINEAR):
        assert isinstance(max_size, int)
        self.max_size = max_size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (torch.tensor): Image to be scaled.

        Returns:
            torch.tensor: Rescaled image.
        """
        if isinstance(img, torch.Tensor):
            assert isinstance(self.interpolation, str)
            return img_tensor_resize(
                img, size=get_resize_size(img, self.max_size),
                mode=self.interpolation, align_corners=False)
        return img_resize(
            img, get_resize_size(img, self.max_size), self.interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, str):
            interpolate_str = self.interpolation
        else:
            interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(max_size={0}, interpolation={1})'.format(
            self.max_size, interpolate_str)


def get_imagenet_transform(min_size=600, max_size=1000):
    """parameters from https://github.com/pytorch/examples/blob/master/imagenet/main.py
    Resize the longer side of the image to max_size, then pad it to a
    max_size x max_size square with the image in the upper-left corner.
    """
    if min_size != 600:
        import warnings
        warnings.warn('min_size is not used in image transform, '
                      'setting min_size will have no effect.')
    return transforms.Compose([
        ImageResize(max_size, Image.BILINEAR),  # longer side will be resized to max_size
        ImagePad(max_size, max_size),  # pad to max_size * max_size
    ])


class ImageNorm(object):
    """Apply Normalization to Image Pixels on GPU
    """
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).cuda().view(1, 1, 3, 1, 1)
        self.std = torch.tensor(std).cuda().view(1, 1, 3, 1, 1)
        # assert max(std) <= 1 and min(std) >= 0\
        #     or max(mean) <= 1 and min(mean) >= 0,\
        #         "Please provide mean or std within range [0, 1]"

    def __call__(self, img):
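        """
        Args:
            img: float image tensors, (B, N, 3, H, W)

        Returns:
            img: normalized float image tensors. The input is modified in
                place; it is first divided by 255 if its max value exceeds 1
                while the given mean is within [0, 1].
        """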
        if torch.max(img) > 1 and self.mean.max() <= 1:
            img.div_(255.)
        return img.sub_(self.mean).div_(self.std)
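
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# the per-pixel arithmetic of ImageNorm on plain floats. Pixels are divided by
# 255 only when some value exceeds 1 and the mean is given in [0, 1] units;
# otherwise they are assumed to be already scaled. One consequence of this
# heuristic: a [0, 255] frame whose pixels are all <= 1 (an almost black
# frame) is not rescaled.
def _sketch_normalize(pixels, mean, std):
    if max(pixels) > 1 and mean <= 1:
        pixels = [p / 255. for p in pixels]
    return [(p - mean) / std for p in pixels]

# e.g. with mean=0.5 and std=0.5, the [0, 255] values 0, 127.5 and 255 map to
# -1.0, 0.0 and 1.0.
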


def chunk_list(examples, chunk_size=2, pad_to_divisible=True):
    """
    Args:
        examples: iterable, examples grouped by image/video
        chunk_size: int, number of examples in each chunk.
        pad_to_divisible: bool, pad the examples to be divisible by chunk_size.
    >>> test_examples = [3, 4, 5, 6, 7]
    >>> chunk_list(test_examples, chunk_size=2, pad_to_divisible=True)
    [[3, 4], [5, 6], [7, 7]]  # the last element has some randomness
    >>> chunk_list(test_examples, chunk_size=2, pad_to_divisible=False)
    [[3, 4], [5, 6], [7]]
    """
    n_examples = len(examples)
    remainder = n_examples % chunk_size
    if pad_to_divisible and remainder > 0:
        n_pad = chunk_size - remainder
        pad = random.choices(examples, k=n_pad)  # with replacement
        examples = examples + pad
        n_examples = len(examples)
        remainder = 0
    chunked_examples = []
    n_chunks = int(n_examples / chunk_size)
    n_chunks = n_chunks + 1 if remainder > 0 else n_chunks
    for i in range(n_chunks):
        chunked_examples.append(examples[i*chunk_size: (i+1)*chunk_size])
    return chunked_examples
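
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# chunk_list with the chunk count written as a ceiling division and an
# injectable random source, so the padding can be reproduced in tests.
import random


def _sketch_chunk_list(examples, chunk_size=2, pad_to_divisible=True, rng=random):
    examples = list(examples)
    remainder = len(examples) % chunk_size
    if pad_to_divisible and remainder > 0:
        # sample with replacement to fill up the last chunk
        examples += rng.choices(examples, k=chunk_size - remainder)
    n_chunks = -(-len(examples) // chunk_size)  # ceil(len / chunk_size)
    return [examples[i * chunk_size:(i + 1) * chunk_size] for i in range(n_chunks)]
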


def mk_input_group(key_grouped_examples, max_n_example_per_group=1, is_train=True,
                   example_unique_key=None):
    """ Re-organize examples into groups. Each input group will have a single image paired
    with X (X=max_n_example_per_group) examples. Images with total #examples > X will be
    split into multiple groups. When is_train is True and a group has < X examples, examples
    are randomly duplicated so that the group has exactly X examples.
    Args:
        key_grouped_examples: dict, each key is image/video id,
            each value is a list(example) associated with this image/video
        max_n_example_per_group: int, pair max #examples with each image/video.
           Note that each image can have multiple groups.
        is_train: bool, if True, copy the examples to make sure each input
            group has max_n_example_per_group examples.
        example_unique_key: str, used to make sure no inputs are discarded by matching
            the input and output ids specified by `example_unique_key`
    """
    input_groups = []  # each element is (id, list(example))
    for k, examples in key_grouped_examples.items():
        chunked_examples = chunk_list(examples,
                                      chunk_size=max_n_example_per_group,
                                      pad_to_divisible=is_train)
        for c in chunked_examples:
            # if len(c) == 0:
            #     continue
            input_groups.append((k, c))

    if example_unique_key is not None:
        print(f"Using example_unique_key {example_unique_key} to check whether input and output ids match")
        # sanity check: make sure we did not discard any input example by accident.
        input_question_ids = flat_list_of_lists(
            [[sub_e[example_unique_key] for sub_e in e] for e in key_grouped_examples.values()])
        output_question_ids = flat_list_of_lists(
            [[sub_e[example_unique_key] for sub_e in e[1]] for e in input_groups])
        assert set(input_question_ids) == set(output_question_ids), "Some input examples are missing from the output groups."
    return input_groups
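
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# what mk_input_group produces at evaluation time (no padding), together with
# the coverage check that example_unique_key enables. Examples are plain dicts
# here; "question_id" is an assumed key name used only for illustration.
def _sketch_group_without_padding(key_grouped_examples, max_per_group, unique_key="question_id"):
    groups = []
    for key, examples in key_grouped_examples.items():
        for i in range(0, len(examples), max_per_group):
            groups.append((key, examples[i:i + max_per_group]))
    # the set of ids before and after grouping must match
    in_ids = {e[unique_key] for exs in key_grouped_examples.values() for e in exs}
    out_ids = {e[unique_key] for _, exs in groups for e in exs}
    assert in_ids == out_ids, "Some input examples are missing from the output groups."
    return groups
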


# def repeat_tensor_rows(raw_tensor, row_repeats):
#     """ repeat raw_tensor[i] row_repeats[i] times.
#     Args:
#         raw_tensor: (B, *)
#         row_repeats: list(int), len(row_repeats) == len(raw_tensor)
#     """
#     assert len(raw_tensor) == len(row_repeats), "Has to be the same length"
#     if sum(row_repeats) == len(row_repeats):
#         return raw_tensor
#     else:
#         indices = torch.LongTensor(
#             flat_list_of_lists([[i] * r for i, r in enumerate(row_repeats)])
#         ).to(raw_tensor.device)
#         return raw_tensor.index_select(0, indices)





================================================
FILE: src/datasets/dataloader.py
================================================
"""
modified from UNITER codebase

A meta data loader for sampling from different datasets / training tasks
A prefetch loader to speed up data loading
"""
import random

import torch
from torch.utils.data import DataLoader
from src.utils.distributed import any_broadcast


class MetaLoader(object):
    """ wraps multiple data loader """
    def __init__(self, loaders, accum_steps=1, distributed=False):
        assert isinstance(loaders, dict)
        self.name2loader = {}
        self.name2iter = {}
        self.sampling_pools = []
        n_batches_in_epoch = 0
        for n, l in loaders.items():
            if isinstance(l, tuple):
                l, r = l
            elif isinstance(l, DataLoader):
                r = 1
            else:
                raise ValueError(f"Expected a DataLoader or a (DataLoader, ratio) tuple for '{n}', got {type(l)}")
            n_batches_in_epoch += len(l.dataset) * r / l.batch_size
            self.name2loader[n] = l
            self.name2iter[n] = iter(l)
            self.sampling_pools.extend([n]*r)
        self.n_batches_in_epoch = n_batches_in_epoch
        self.accum_steps = accum_steps
        self.distributed = distributed
        self.step = 0

    def __iter__(self):
        """ this iterator will run indefinitely """
        task = self.sampling_pools[0]
        while True:
            if self.step % self.accum_steps == 0:
                task = random.choice(self.sampling_pools)
                if self.distributed:
                    # make sure all processes are training the same task
                    task = any_broadcast(task, 0)
            self.step += 1
            iter_ = self.name2iter[task]
            try:
                batch = next(iter_)
            except StopIteration:
                iter_ = iter(self.name2loader[task])
                batch = next(iter_)
                self.name2iter[task] = iter_

            yield task, batch
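
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# the task schedule that MetaLoader.__iter__ follows, without any data loading.
# Each task name appears `ratio` times in the pool, so it is drawn with
# probability ratio / sum(ratios). A new task is drawn only every accum_steps
# steps, so all micro-batches accumulated into one optimizer step come from the
# same task. The distributed broadcast is omitted; task names are made up.
import random


def _sketch_task_schedule(ratios, accum_steps, n_steps, rng=random):
    pool = [name for name, r in ratios.items() for _ in range(r)]
    task = pool[0]
    schedule = []
    for step in range(n_steps):
        if step % accum_steps == 0:
            task = rng.choice(pool)
        schedule.append(task)
    return schedule
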


def move_to_cuda(batch):
    if isinstance(batch, torch.Tensor):
        return batch.cuda(non_blocking=True)
    elif isinstance(batch, list):
        new_batch = [move_to_cuda(t) for t in batch]
    elif isinstance(batch, tuple):
        new_batch = tuple(move_to_cuda(t) for t in batch)
    elif isinstance(batch, dict):
        new_batch = {n: move_to_cuda(t) for n, t in batch.items()}
    else:
        return batch
    return new_batch
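
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# move_to_cuda is a recursive map over nested lists, tuples and dicts that
# transforms only tensor leaves and returns every other leaf unchanged. The
# same traversal, with the leaf type and transform as parameters, can be
# checked without a GPU; with leaf_type=torch.Tensor and
# fn=lambda t: t.cuda(non_blocking=True) it behaves like move_to_cuda.
def _sketch_map_nested(batch, fn, leaf_type):
    if isinstance(batch, leaf_type):
        return fn(batch)
    if isinstance(batch, list):
        return [_sketch_map_nested(t, fn, leaf_type) for t in batch]
    if isinstance(batch, tuple):
        return tuple(_sketch_map_nested(t, fn, leaf_type) for t in batch)
    if isinstance(batch, dict):
        return {k: _sketch_map_nested(v, fn, leaf_type) for k, v in batch.items()}
    return batch  # e.g. strings or None pass through untouched
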


def record_cuda_stream(batch):
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
    elif isinstance(batch, list) or isinstance(batch, tuple):
        for t in batch:
            record_cuda_stream(t)
    elif isinstance(batch, dict):
        for t in batch.values():
            record_cuda_stream(t)
    else:
        pass


class PrefetchLoader(object):
    """
    overlap compute and cuda data transfer
    (copied and then modified from nvidia apex)
    """
    def __init__(self, loader, img_normalize=None):
        self.loader = loader
        self.stream = torch.cuda.Stream()
        self.img_normalize = img_normalize

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            is_tuple = isinstance(batch, tuple)
            if is_tuple:
                task, batch = batch
            batch["visual_inputs"] = batch["visual_inputs"].float()
            if self.img_normalize is not None:
                batch["visual_inputs"] = self.img_normalize(
                    batch["visual_inputs"])
                if "crop_visual_inputs" in batch:
                    batch["crop_visual_inputs"] = batch["crop_visual_inputs"].float()
                    batch["crop_visual_inputs"] = self.img_normalize(
                        batch["crop_visual_inputs"])
                if "context_visual_inputs" in batch:
                    batch["context_visual_inputs"] = batch["context_visual_inputs"].float()
                    batch["context_visual_inputs"] = self.img_normalize(
                        batch["context_visual_inputs"])
            if is_tuple:
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # if record_stream() doesn't work, another option is to make sure
        # device inputs are created on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input,
        #                                        device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target,
        #                                         device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use
        # by the main stream at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this
            # side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu

    def next(self, it):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        method = self.loader.__getattribute__(name)
        return method


class InfiniteIterator(object):
    """iterate an iterable object infinitely"""
    def __init__(self, iterable):
        self.iterable = iterable
        self.iterator = iter(iterable)

    def __iter__(self):
        while True:
            try:
                batch = next(self.iterator)
            except StopIteration:
                self.iterator = iter(self.iterable)
                batch = next(self.iterator)
            yield batch
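
# Illustrative sketch (hypothetical helpers, not part of the original ALPRO code):
# InfiniteIterator restarts by calling iter() on the wrapped iterable again.
# itertools.cycle instead replays the items saved from the first pass, so with
# a reshuffling iterable (such as a DataLoader with shuffle=True) only the
# restart approach produces a fresh order every epoch. The iterable must be
# non-empty; otherwise the second next() raises StopIteration inside the
# generator, which Python 3.7+ turns into a RuntimeError.
import itertools
import random


class _SketchShuffledEpochs(object):
    """Hypothetical stand-in for a shuffling loader: each iter() call yields a new order."""
    def __init__(self, items, seed=0):
        self.items = list(items)
        self.rng = random.Random(seed)

    def __iter__(self):
        order = self.items[:]
        self.rng.shuffle(order)
        return iter(order)


def _sketch_restart_forever(iterable):
    # same restart logic as InfiniteIterator.__iter__
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)
            yield next(iterator)
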


================================================
FILE: src/datasets/dataset_base.py
================================================
from torch.utils.data import Dataset
from PIL import Image
import io
import av
import torch
import numpy as np
import lmdb
import random
import decord
from decord import VideoReader
from src.datasets.data_utils import (
    ImageResize, ImagePad, image_to_tensor)
from src.utils.load_save import LOGGER

decord.bridge.set_bridge("torch")


class AlproBaseDataset(Dataset):
    """
    datalist: list(dicts)  # lightly pre-processed
        {
        "type": "image",
        "filepath": "/abs/path/to/COCO_val2014_000000401092.jpg",
        "text": "A plate of food and a beverage are on a table.",
                # should be tokenized and digitized first?
        ...
        }
    tokenizer:
    max_img_size: int,
    max_txt_len: int, max text sequence length, including special tokens.
    fps: float, frames per second
    num_frm: int, number of frames to use as input.
    """

    def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type='lmdb', fps=3, num_frm=3,
                 frm_sampling_strategy="rand", max_img_size=-1, max_txt_len=20):
        self.fps = fps
        self.num_frm = num_frm
        self.frm_sampling_strategy = frm_sampling_strategy
        self.datalist = datalist
        self.tokenizer = tokenizer
        self.max_txt_len = max_txt_len
        self.max_img_size = max_img_size
        self.img_resize = ImageResize(
            max_img_size,
            "bilinear")  # longer side will be resized to max_img_size
        self.img_pad = ImagePad(
            max_img_size, max_img_size)  # pad to max_img_size * max_img_size

        self.img_db_type = img_db_type

        assert img_db_type in ['lmdb', 'rawvideo'], "Invalid type for img_db_type, expected {{'lmdb', 'rawvideo'}}, found {}.".format(img_db_type)

        if self.img_db_type == 'lmdb':
            self.env = lmdb.open(
                img_lmdb_dir, readonly=True,
                create=False)  # readahead=not _check_distributed()
            self.txn = self.env.begin(buffers=True)
        else:
            self.img_db_dir = img_lmdb_dir

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        raise NotImplementedError

    def _load_img(self, img_id):
        """Load and apply transformation to image

        Returns:
            torch.float, in [0, 255], (n_frm=1, c, h, w)
        """
        raw_img = load_decompress_img_from_lmdb_value(
            self.txn.get(str(img_id).encode("utf-8"))
        )
        image_np = np.array(raw_img, dtype=np.uint8)  # (h, w, c)
        raw_img_tensor = image_to_tensor(
            image_np, keepdim=False).float()  # (1, c, h, w) [0, 255]
        resized_img = self.img_resize(raw_img_tensor)
        transformed_img = self.img_pad(
            resized_img)  # (n_frm=1, c, h, w)
        return transformed_img

    @classmethod
    def _is_extreme_aspect_ratio(cls, tensor, max_ratio=5.):
        """ find extreme aspect ratio, where longer side / shorter side > max_ratio
        Args:
            tensor: (*, H, W)
            max_ratio: float, max ratio (>1).
        """
        h, w = tensor.shape[-2:]
        return h / float(w) > max_ratio or h / float(w) < 1 / max_ratio

    def _load_video(self, video_id, num_clips=None, clip_idx=None,
                    safeguard_duration=False, video_max_pts=None):
        """Load and sample frames from video.
        Apply transformation to the sampled frames.

        Sample a clip:
            - random: set num_clips and clip_idx to be None
            - uniform: set num_clips=N, clip_idx=idx. e.g., num_clips=3
                and clip_idx=1 will first segment the video into 3 clips,
                then sample the 2nd clip.

        Returns:
            torch.float, in [0, 255], (n_frm=T, c, h, w)
        """
        assert (num_clips is None) == (clip_idx is None), "Both None, or both not None"
        # (T, C, H, W) [0, 255]
        io_stream = io.BytesIO(self.txn.get(str(video_id).encode("utf-8")))
        raw_sampled_frms, video_max_pts = extract_frames_from_video_binary(
            io_stream,
            target_fps=self.fps,
            num_frames=self.num_frm,
            multi_thread_decode=False,
            sampling_strategy=self.frm_sampling_strategy,
            num_clips=num_clips,
            clip_idx=clip_idx,
            safeguard_duration=safeguard_duration,
            video_max_pts=video_max_pts
        )

        if raw_sampled_frms is None:
            return None, None
        elif self._is_extreme_aspect_ratio(raw_sampled_frms, max_ratio=5.):
            print(
                f"Found extreme aspect ratio for video id {video_id}. Skip it")
            return None, None

        raw_sampled_frms = raw_sampled_frms.float()
        resized_frms = self.img_resize(raw_sampled_frms)
        padded_frms = self.img_pad(resized_frms)
        return padded_frms, video_max_pts


    def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
        try:
            if not height or not width:
                vr = VideoReader(video_path)
            else:
                vr = VideoReader(video_path, width=width, height=height)

            vlen = len(vr)

            if start_time or end_time:
                assert fps > 0, 'must provide video fps if specifying start and end time.'

                start_idx = min(int(start_time * fps), vlen)
                end_idx = min(int(end_time * fps), vlen)
            else:
                start_idx, end_idx = 0, vlen

            # spacing between sampled frames inside [start_idx, end_idx); computing it from the
            # clip length (not vlen) keeps num_frm frames when start/end times are given
            step = (end_idx - start_idx) / self.num_frm

            if self.frm_sampling_strategy == 'uniform':
                # cast after arange: with dtype=int, np.arange would truncate the float step itself
                frame_indices = np.arange(start_idx, end_idx, step).astype(int)
            elif self.frm_sampling_strategy == 'nlvl_uniform':
                frame_indices = np.arange(start_idx, end_idx, step).astype(int)
            elif self.frm_sampling_strategy == 'nlvl_rand':
                frame_indices = np.arange(start_idx, end_idx, step).astype(int)

                # randomly perturb each index within its segment (max(.., 1) avoids randint(0, 0))
                strides = [frame_indices[i] - frame_indices[i-1] for i in range(1, len(frame_indices))] + [end_idx - frame_indices[-1]]
                pertube = np.array([np.random.randint(0, max(stride, 1)) for stride in strides])

                frame_indices = frame_indices + pertube

            elif self.frm_sampling_strategy == 'rand':
                frame_indices = sorted(random.sample(range(vlen), self.num_frm))
            elif self.frm_sampling_strategy == 'headtail':
                frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
                frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
                frame_indices = frame_indices_head + frame_indices_tail
            else:
                raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))

            raw_sample_frms = vr.get_batch(frame_indices)
        except Exception as e:
            return None

        raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)

        return raw_sample_frms
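
# Illustrative sketch (hypothetical helpers, not part of the original ALPRO code):
# segment-based frame sampling on plain integers. The clip [start, end) is split
# into num_frm equal segments. The "uniform" variant takes the first frame of
# each segment, and the "jittered" variant draws one random frame per segment,
# in the spirit of the 'nlvl_rand' strategy above but not an exact copy of it.
import random


def _sketch_segment_bounds(start, end, num_frm):
    if end <= start:
        raise ValueError("empty frame range")
    length = end - start
    return [start + int(i * length / num_frm) for i in range(num_frm)] + [end]


def _sketch_uniform_indices(start, end, num_frm):
    return _sketch_segment_bounds(start, end, num_frm)[:-1]


def _sketch_jittered_indices(start, end, num_frm, rng=random):
    b = _sketch_segment_bounds(start, end, num_frm)
    # max(..., b[i] + 1) keeps randrange valid when a segment is empty (short clips)
    return [rng.randrange(b[i], max(b[i + 1], b[i] + 1)) for i in range(num_frm)]
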

def img_collate(imgs):
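    """
    Args:
        imgs: list(PIL Image), images of identical size; single-channel
            images are broadcast to 3 channels.

    Returns:
        torch.tensor, (B, 3, H, W), dtype torch.uint8
    """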
    w = imgs[0].width
    h = imgs[0].height
    tensor = torch.zeros(
        (len(imgs), 3, h, w), dtype=torch.uint8).contiguous()
    for i, img in enumerate(imgs):
        nump_array = np.array(img, dtype=np.uint8)
        if (nump_array.ndim < 3):
            nump_array = np.expand_dims(nump_array, axis=-1)
        # (H, W, 3) --> (3, H, W)
        nump_array = np.rollaxis(nump_array, 2)
        tensor[i] += torch.from_numpy(nump_array)
    return tensor
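
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# the layout change img_collate performs, written with NumPy only. Each
# (H, W) or (H, W, C) array becomes (C, H, W), a single-channel image is
# broadcast to 3 channels (the `+=` in img_collate relies on the same
# broadcasting), and the results are stacked into a (B, 3, H, W) uint8 batch.
import numpy as np


def _sketch_collate_hwc(arrays):
    chw = []
    for a in arrays:
        a = np.asarray(a, dtype=np.uint8)
        if a.ndim < 3:
            a = a[:, :, None]                   # (H, W) -> (H, W, 1)
        a = a.transpose(2, 0, 1)                # (H, W, C) -> (C, H, W)
        chw.append(np.broadcast_to(a, (3,) + a.shape[1:]))
    return np.stack(chw)                        # (B, 3, H, W)
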


================================================
FILE: src/datasets/dataset_pretrain_sparse.py
================================================
import os
import json
import random

import torch
import spacy
from torch.utils.data.dataloader import default_collate
from src.utils.logger import LOGGER
from src.utils.basic_utils import flat_list_of_lists, save_frames_grid
from src.datasets.data_utils import VideoRandomSquareCrop, VideoResizeSquare, mask_batch_text_tokens, select_batch_text_pivots
from src.datasets.dataset_base import AlproBaseDataset, img_collate

from src.datasets.randaugment import TemporalConsistentRandomAugment, RandomAugment

from torch.utils.data import Dataset

from torchvision import transforms
from PIL import Image
import numpy as np


class AlproPretrainSparseDataset(AlproBaseDataset):
    """
    datalist: annotation table accessed row-wise via `.iloc` (e.g. a pandas DataFrame);
        each row provides `video_id`, `txt_len` and `text`.
    tokenizer:
    img_lmdb_dir: str, with img_db_type='rawvideo', the directory holding the video files,
        each read from os.path.join(img_lmdb_dir, video_id + video_fmt).
    max_img_size: int,
    max_txt_len: int, max text sequence length, including special tokens.
    crop_size: int, side length of the random square crop applied to each video.
    resize_size: int, frames are decoded at resize_size x resize_size before cropping.
    is_train: bool, if True, apply temporally consistent RandAugment to the cropped frames.
    """
    def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type, txt_dir,
                video_fmt='.mp4', crop_size=256, resize_size=288, fps=3, num_frm=3, frm_sampling_strategy="rand",
                max_img_size=1000, max_txt_len=20,
                use_itm=True, is_train=True):
        super(AlproPretrainSparseDataset, self).__init__(
            datalist, tokenizer, img_lmdb_dir, 
            img_db_type=img_db_type,
            fps=fps, 
            num_frm=num_frm, 
            frm_sampling_strategy=frm_sampling_strategy,
            max_img_size=max_img_size, 
            max_txt_len=max_txt_len)
        self.use_itm = use_itm

        self.txt_dir = txt_dir
        self.video_fmt = video_fmt

        self.crop_size = crop_size
        self.video_random_cropper = VideoRandomSquareCrop(crop_size)

        self.resize_size = resize_size

        self.is_train = is_train

        if self.is_train:
            self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])     
        else:
            self.randaug = None

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        start_time = None
        end_time = None

        # fetch video
        num_retries = 10  # skip error videos

        for _ in range(num_retries):
            data_sample = self.datalist.iloc[index]

            video_id = str(data_sample.video_id)
            txt_len = int(data_sample.txt_len)

            if hasattr(data_sample, 'text'):
                text = data_sample.text.strip()
            else:
                raise NotImplementedError("Un-supported text annotation format.")

            # fetch video
            video_path = os.path.join(self.img_db_dir, video_id + self.video_fmt) 

            # read with retries
            for i in range(3):
                img_array = self._load_video_from_path_decord(video_path, height=self.resize_size, width=self.resize_size)

                if img_array is not None:
                    break

            if img_array is not None:
                t, c, h, w = img_array.shape

            # Sample a random replacement if the current video could not be loaded.
            if img_array is None:
                LOGGER.info(f"Failed to load examples with video: {video_path}. "
                            f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue
            else:
                # square crop
                img_array = self.video_random_cropper(img_array)

                if self.randaug:
                    img_array = self.randaug(img_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

                break
        else:
            raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
        
        examples = [{'text_str': text, 'itm_label': 1}]

        return dict(
            img=img_array,  # (T, C, H, W)
            examples=examples,
            n_examples=len(examples),  # used to create image feature copies.
            type='video'
        )
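
# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# the retry pattern used by __getitem__ above, reduced to its control flow.
# A failed load (signalled by None) switches to a uniformly random replacement
# index, and after num_retries failures an error is raised instead of looping
# forever. load_fn is an assumed callable standing in for video decoding.
# The returned index may differ from the requested one.
import random


def _sketch_load_with_fallback(index, n_items, load_fn, num_retries=10, rng=random):
    for _ in range(num_retries):
        item = load_fn(index)
        if item is not None:
            return index, item
        index = rng.randint(0, n_items - 1)   # inclusive on both ends
    raise RuntimeError(f"Failed to fetch an item after {num_retries} retries.")
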

class PretrainImageTextDataset(Dataset):
    def __init__(self, datalist, tokenizer, is_train=True, crop_size=256, resize_size=288, num_frm=4, max_txt_len=40):
        self.datalist = datalist
        self.max_txt_len = max_txt_len

        self.crop_size = crop_size
        self.resize_size = resize_size
        self.num_frms = num_frm

        self.is_train = is_train

        self.transform = transforms.Compose([                        
                transforms.RandomResizedCrop(self.crop_size, scale=(0.2, 1.0), interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
                RandomAugment(2,7,isPIL=True,augs=['Identity','Brightness','Sharpness',
                                                'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'])     
            ])    
        
    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        start_time = None
        end_time = None

        # fetch image
        num_retries = 10  # skip error images

        for _ in range(num_retries):
            data_sample = self.datalist[index]
            img_path = None  # defined before the try block so the failure log below cannot raise NameError

            try:
                if isinstance(data_sample['caption'], list):
                    text = random.choice(data_sample['caption'])
                else:
                    text = data_sample['caption']

                img_path = data_sample['image']
                img_arr = Image.open(img_path).convert('RGB')
                img_arr = self.transform(img_arr)
                # (H, W, C) -> (C, H, W), float pixel values in [0, 255]
                img_arr = np.asarray(img_arr, dtype=np.float32).transpose(2, 0, 1)
                # repeat the single image num_frms times to form a static clip: (T, C, H, W)
                img_arr = torch.from_numpy(img_arr).unsqueeze(0)
                img_arr = img_arr.repeat(self.num_frms, 1, 1, 1)

            except Exception as e:
                img_arr = None

            if img_arr is not None:
                t, c, h, w = img_arr.shape

            # Sample a random replacement if the current image could not be loaded.
            if img_arr is None:
                LOGGER.info(f"Failed to load examples with image: {img_path}. "
                            f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue
            else:
                break
        else:
            raise RuntimeError(f"Failed to fetch image after {num_retries} retries.")
        
        examples = [{'text_str': text, 'itm_label': 1}]

        return dict(
            img=img_arr,  # (T, C, H, W)
            examples=examples,
            n_examples=len(examples),  # used to create image feature copies.
            type='img'
        )
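

# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# the image dataset above turns a single image into a "video" by adding a time
# axis and repeating it num_frms times, so image-text pairs can share the video
# pipeline. This sketch assumes a (C, H, W) float tensor as input and shows that
# Tensor.expand yields the same values as Tensor.repeat without copying memory
# (the expanded result is a view and should not be modified in place).
import torch


def _example_image_to_clip(img_chw, num_frms):
    """Return two (num_frms, C, H, W) clips built from a (C, H, W) image."""
    repeated = img_chw.unsqueeze(0).repeat(num_frms, 1, 1, 1)  # materialized copy
    expanded = img_chw.unsqueeze(0).expand(num_frms, -1, -1, -1)  # view, no copy
    return repeated, expanded


# Usage: rep, exp = _example_image_to_clip(torch.rand(3, 8, 8), num_frms=4)
# gives two (4, 3, 8, 8) tensors with identical contents.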


class PretrainCollator(object):
    """Collate pretraining samples into a batch.

    is_train is kept so that the randomness of MLM masking can be removed when
    validating MLM accuracy; in that case, instantiate two PretrainCollator
    objects (one for training, one for validation)."""
    def __init__(self, tokenizer, 
                 mlm=True, mlm_probability=0.15,
                 patch_size=16,
                 mpm=True,
                 max_length=20, is_train=True):
        self.tokenizer = tokenizer
        self.mlm = mlm
        self.mlm_probability = mlm_probability
        self.max_length = max_length
        self.is_train = is_train

        self.mpm = mpm
        self.patch_size = patch_size

    def collate_batch(self, batch):
        if isinstance(batch[0]["img"], torch.Tensor):
            v_collate = default_collate
        else:
            v_collate = img_collate
        visual_inputs = v_collate([d["img"] for d in batch])  # (B, #frm=1 or T, 3, H, W)
        # group data
        text_examples = flat_list_of_lists([d["examples"] for d in batch])
        n_examples_list = [d["n_examples"] for d in batch]  # (B, )
        # group elements data
        batch_enc = self.tokenizer.batch_encode_plus(
            [d["text_str"] for d in text_examples],
            max_length=self.max_length,
            padding='max_length',
            return_tensors="pt",
            truncation=True
        )
        text_input_ids = batch_enc.input_ids  # (B, L)
        text_input_ids_no_mask = text_input_ids.clone()

        if self.mlm:
            text_input_ids, mlm_labels = mask_batch_text_tokens(
                text_input_ids, self.tokenizer,
                is_train=self.is_train)  # make mlm data
        else:
            text_input_ids, mlm_labels = text_input_ids, None
        
        text_input_mask = batch_enc.attention_mask  # (B, L)
        itm_labels = default_collate(
            [d["itm_label"] for d in text_examples])  # (B, )
        
        if self.mpm:
            # Erase one random patch-aligned box per sample; only needed for MPM.
            erase_elems = [random_erase(e, patch_size=self.patch_size) for e in visual_inputs.clone()]
            crop_visual_inputs = v_collate([elems[0] for elems in erase_elems])
            mpm_masks = v_collate([elems[1] for elems in erase_elems])
            context_visual_inputs = v_collate([elems[2] for elems in erase_elems])

            return dict(
                visual_inputs=visual_inputs,  # (B, #frm=1 or T, C, H, W)
                crop_visual_inputs=crop_visual_inputs,
                context_visual_inputs=context_visual_inputs,
                mpm_mask=mpm_masks,
                text_input_ids=text_input_ids_no_mask,
                mlm_text_input_ids=text_input_ids,
                mlm_labels=mlm_labels,
                text_input_mask=text_input_mask, # used to exclude [PAD] token
                itm_labels=itm_labels,
                n_examples_list=n_examples_list,  # used to create image feature copies.
                type=batch[0]['type']
            )
        else:
            return dict(
                visual_inputs=visual_inputs,  # (B, #frm=1 or T, C, H, W)
                text_input_ids=text_input_ids_no_mask,
                mlm_text_input_ids=text_input_ids,
                mlm_labels=mlm_labels,
                text_input_mask=text_input_mask, # used to exclude [PAD] token
                itm_labels=itm_labels,
                n_examples_list=n_examples_list,  # used to create image feature copies.
                type=batch[0]['type']
            )

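def random_erase(input_img, patch_size, s_l=0.3, s_h=0.5, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255):
    """Split a clip into a cropped box and its surrounding context.

    A box covering roughly s_l to s_h of the frame area, with aspect ratio
    h / w in [r_1, r_2], is sampled and snapped to the patch grid. The same
    box is used for every frame.

    Args:
        input_img: (T, C, H, W) tensor.
        patch_size: ViT patch size; box position and size are multiples of it.
        v_l, v_h: currently unused.

    Returns:
        crop_img: (T, C, H, W), pixels inside the box kept, zeros elsewhere.
        img_masks: (H // patch_size, W // patch_size), 0 for patches inside
            the box and 1 for patches outside it.
        context_img: (T, C, H, W), zeros inside the box, original pixels elsewhere.
    """
    assert input_img.ndim == 4
    img_t, img_c, img_h, img_w = input_img.shape

    # Rejection-sample a box until it fits inside the frame.
    while True:
        s = np.random.uniform(s_l, s_h) * img_h * img_w  # box area
        r = np.random.uniform(r_1, r_2)  # aspect ratio h / w
        w = int(np.sqrt(s / r))
        h = int(np.sqrt(s * r))
        left = np.random.randint(0, img_w)
        top = np.random.randint(0, img_h)

        # Snap size and position to the patch grid.
        w = w - w % patch_size
        h = h - h % patch_size
        left = left - left % patch_size
        top = top - top % patch_size

        if left + w <= img_w and top + h <= img_h:
            break

    # Context view: blank out the box.
    context_img = input_img.clone()
    context_img[:, :, top: top + h, left: left + w] = 0

    # Crop view: keep only the box, zero-padded back to the full frame size.
    crop_img = input_img[:, :, top: top + h, left: left + w]
    pad = (left, img_w - left - w, top, img_h - top - h)
    crop_img = torch.nn.functional.pad(crop_img, pad=pad, mode='constant', value=0.0)

    # Patch-level mask: 0 for patches inside the box, 1 outside.
    img_masks = torch.ones_like(crop_img)
    img_masks[:, :, top: top + h, left: left + w] = 0
    img_masks = torch.nn.functional.avg_pool2d(img_masks.float(), kernel_size=(patch_size, patch_size), stride=patch_size)
    img_masks = torch.mean(img_masks, dim=(0, 1))

    return crop_img, img_masks, context_img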
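

# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# random_erase derives its patch-level mask by average-pooling a pixel mask with
# kernel = stride = patch_size. Because the box is snapped to the patch grid, each
# patch lies fully inside or fully outside the box, so the pooled values are
# exactly 0 (inside) or 1 (outside). The box coordinates in the usage note are
# arbitrary examples chosen as multiples of patch_size.
import torch
import torch.nn.functional as F


def _example_box_patch_mask(img_h, img_w, top, left, h, w, patch_size):
    """Return an (img_h // patch_size, img_w // patch_size) mask: 0 in the box, 1 outside."""
    pixel_mask = torch.ones(1, 1, img_h, img_w)
    pixel_mask[:, :, top: top + h, left: left + w] = 0
    patch_mask = F.avg_pool2d(pixel_mask, kernel_size=patch_size, stride=patch_size)
    return patch_mask[0, 0]


# Usage: a 32x48 (h x w) box at top=16, left=16 on a 64x64 frame with 16-pixel
# patches zeroes patch rows 1-2 and patch columns 1-3 of the 4x4 patch grid:
# mask = _example_box_patch_mask(64, 64, top=16, left=16, h=32, w=48, patch_size=16)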

================================================
FILE: src/datasets/dataset_video_qa.py
================================================
import os
import torch
import random
import numpy as np
import copy
from torch.utils.data.dataloader import default_collate
from src.utils.basic_utils import flat_list_of_lists
from src.utils.load_save import LOGGER
from src.datasets.dataset_base import AlproBaseDataset
from src.datasets.randaugment import TemporalConsistentRandomAugment


class AlproVideoQADataset(AlproBaseDataset):
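    """ This should work for both train and test (where labels are not available).
    task_type: str, one of [action, frameqa, transition, msrvtt_qa, msvd_qa],
        where action and transition are multiple-choice QA,
        and frameqa, msrvtt_qa and msvd_qa are open-ended QA, similar to VQA.
    datalist: list(tuples), each tuple is (vid_id, list(dicts)),
        each dict is one QA example with keys "question", "question_id",
        "answer" and, for open-ended QA evaluation, "answer_type".
    tokenizer: text tokenizer, passed to AlproBaseDataset.
    max_img_size: int, frame height and width passed to the video loader.
    max_txt_len: int, max text sequence length, including special tokens.
    return_label: bool, whether to return labels in __getitem__.
    random_sample_clips: bool, whether to use randomly sampled clips (stored, not used by this class).
    """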
    open_ended_qa_names = ["frameqa", "msrvtt_qa", "msvd_qa"]

    def __init__(self, task_type, datalist, tokenizer, img_lmdb_dir,
                 fps=3, num_frm=3, frm_sampling_strategy="rand",
                 max_img_size=1000, max_txt_len=20, ans2label=None,
                 ensemble_n_clips=1, return_label=True, is_train=False, random_sample_clips=True, 
                 video_fmt='.mp4', img_db_type='lmdb'):
        super(AlproVideoQADataset, self).__init__(
            datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type,
            fps=fps, num_frm=num_frm,
            frm_sampling_strategy=frm_sampling_strategy,
            max_img_size=max_img_size, max_txt_len=max_txt_len)
        self.ensemble_n_clips = ensemble_n_clips
        self.return_label = return_label
        self.is_train = is_train
        self.task_type = task_type
        self.ans2label = ans2label
        self.num_labels = len(ans2label)
        self.random_sample_clips = random_sample_clips
        self.label2ans = {v: k for k, v in ans2label.items()}
        self.qid2data = {d["question_id"]: d for group in datalist for d in group[1]}

        self.video_fmt = video_fmt

        if self.is_train:
            self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])     
        else:
            self.randaug = None

    def __len__(self):
        return len(self.datalist)


    def __getitem__(self, index):
        # skip error videos:
        num_retries = 5
        for _ in range(num_retries):
            vid_id, examples = self.datalist[index]  # one video with multiple examples
            if self.ensemble_n_clips > 1:
                raise NotImplementedError('Do not support multiple clips for now.')
            else:
                video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) 
                vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)

            # Sample a random replacement if the current video could not be loaded.
            if vid_frm_array is None:
                LOGGER.info(f"Failed to load examples with video: {vid_id}. "
                            f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue

            if self.randaug is not None:
                # (T, C, H, W) -> (T, H, W, C) for the augmenter, then back to (T, C, H, W)
                vid_frm_array = self.randaug(vid_frm_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

            examples = [self._get_single_example(e) for e in examples]
            return dict(
                vid=vid_frm_array,
                examples=examples,
                n_examples=len(examples)  # used to create image feature copies.
            )
        else:
            raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")

    def _get_single_example(self, data):
        example = dict(
            q_str=data["question"],
            question_id=data["question_id"],
            label=data["answer"]
        )
        if self.task_type in self.open_ended_qa_names:
            if self.return_label:
                example["label"] = self.ans2label[example["label"]]
        if not self.return_label:
            example["label"] = None
        return example

    def evaluate_qa(self, results):
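        """
        Args:
            results: list(dict),
              each dict is
                {
                    "question_id": int,
                    "answer": int, the predicted answer index (for open-ended QA
                        it is mapped back to an answer string via label2ans)
                }
        Returns:
            dict of metrics: "overall_acc" and, for open-ended QA, per-answer-type
            accuracies "{answer_type}_acc" plus "ratios" (fraction and count of
            questions of each answer type).
        """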
        preds = []
        gts = []
        # for frameQA
        answer_types = []
        answer_type2idx = dict(
            frameqa={"object": 0, "number": 1, "color": 2, "location": 3},
            msrvtt_qa={k: idx for idx, k in enumerate(["what", "who", "how", "where", "when"])},
            msvd_qa={k: idx for idx, k in enumerate(["what", "who", "how", "where", "when"])}
        )

        qid2pred_ans = {r["question_id"]: r["answer"] for r in results}
        if self.task_type in self.open_ended_qa_names:  # convert ans_idx, int --> str
            qid2pred_ans = {k: self.label2ans[v] for k, v in qid2pred_ans.items()}

        for qid, pred_ans in qid2pred_ans.items():
            preds.append(pred_ans)

            gt_data = self.qid2data[qid]
            gt_ans = gt_data["answer"]
            if self.task_type in self.open_ended_qa_names:
                answer_types.append(answer_type2idx[self.task_type][gt_data["answer_type"]])
            gts.append(gt_ans)

        preds = np.array(preds)
        gts = np.array(gts)
        metrics = dict()
        # preds and gts are array of strings
        metrics["overall_acc"] = float(np.mean(preds == gts))
        if self.task_type in self.open_ended_qa_names:
            answer_types = np.array(answer_types)
            ratios = dict()
            for ans_type, ans_type_idx in answer_type2idx[self.task_type].items():
                answer_type_mask = answer_types == ans_type_idx
                answer_type_corrects = (
                        preds[answer_type_mask] == gts[answer_type_mask])
                metrics[f"{ans_type}_acc"] = float(
                    np.mean(answer_type_corrects)) if len(answer_type_corrects) != 0 else 0
                ratios[f"{ans_type}_ratio"] = [
                    1. * len(answer_type_corrects) / len(answer_types),
                    len(answer_type_corrects)]
            metrics["ratios"] = ratios
        return metrics
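

# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# evaluate_qa reports overall accuracy plus one accuracy per answer type by
# boolean-masking the prediction and ground-truth arrays. This standalone version
# assumes plain lists of answer strings and answer-type names and, like
# evaluate_qa, reports 0 for an answer type with no questions.
import numpy as np


def _example_qa_accuracy(preds, gts, answer_types, type_names):
    preds, gts, answer_types = np.array(preds), np.array(gts), np.array(answer_types)
    correct = preds == gts
    metrics = {"overall_acc": float(np.mean(correct))}
    for name in type_names:
        mask = answer_types == name
        metrics[f"{name}_acc"] = float(np.mean(correct[mask])) if mask.any() else 0
    return metrics


# Usage: _example_qa_accuracy(["dog", "two"], ["dog", "three"], ["what", "how"],
#                             ["what", "who", "how", "where", "when"])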


class VideoQACollator(object):
    def __init__(self, tokenizer, max_length=20, task_type="action", n_options=5):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.task_type = task_type
        self.n_options = n_options

    def collate_batch(self, batch):
        v_collate = default_collate
        visual_inputs = v_collate([d["vid"] for d in batch])  # (B, T, 3, H, W)
        # group data
        text_examples = flat_list_of_lists([d["examples"] for d in batch])
        n_examples_list = [d["n_examples"] for d in batch]  # (B, )
        # group elements data
        # directly concatenate question and option as a single seq.
        if self.task_type in ["action", "transition"]:
            text_str_list = flat_list_of_lists(
                [[d["q_str"] + " " + d["options_str_list"][i] for i in range(self.n_options)]
                 for d in text_examples]
            )  # (B * n_options, )
        else:
            text_str_list = [d["q_str"] for d in text_examples]  # (B, )
        batch_enc = self.tokenizer.batch_encode_plus(
            text_str_list,
            max_length=self.max_length,
            padding='max_length',
            return_tensors="pt",
            truncation=True
        )
        text_input_ids = batch_enc.input_ids  # (B, L)
        text_input_mask = batch_enc.attention_mask  # (B, L)

        labels = default_collate([int(d["label"]) for d in text_examples]) \
            if text_examples[0]["label"] is not None else None  # (B, )
        question_ids = [d["question_id"] for d in text_examples]
        return dict(
            visual_inputs=visual_inputs,  # (B, #frm, C, H, W)
            text_input_ids=text_input_ids,
            text_input_mask=text_input_mask,
            question_ids=question_ids,
            labels=labels,
            n_examples_list=n_examples_list  # used to create image feature copies.
        )
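

# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# the collators return n_examples_list so that per-video visual features can be
# copied once per text example (e.g. a video with 3 questions yields 3 copies).
# Assuming per-video features of shape (B, D), torch.repeat_interleave produces the
# aligned (sum(n_examples_list), D) tensor; the model code may implement this step
# differently.
import torch


def _example_copy_visual_features(video_feats, n_examples_list):
    repeats = torch.tensor(n_examples_list, device=video_feats.device)
    return torch.repeat_interleave(video_feats, repeats, dim=0)


# Usage: features of 2 videos with n_examples_list=[2, 1] -> 3 rows (video 0 twice).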


================================================
FILE: src/datasets/dataset_video_retrieval.py
================================================
import random
import copy
import os
import torch
import numpy as np
from torch.utils.data.dataloader import default_collate
from src.utils.basic_utils import flat_list_of_lists
from src.utils.load_save import LOGGER
from src.datasets.dataset_base import AlproBaseDataset
from src.datasets.randaugment import TemporalConsistentRandomAugment


class AlproVideoRetrievalDataset(AlproBaseDataset):
    """ This should work for both train and test (where labels are not available).
    datalist: list(tuples), each tuple is (vid_id, list(dicts)),
        each dict is one caption with keys "id" (caption id) and "txt" (caption text).
    tokenizer: text tokenizer, passed to AlproBaseDataset.
    max_img_size: int, frame height and width passed to the video loader.
    max_txt_len: int, max text sequence length, including special tokens.
    random_sample_clips: bool, whether to use N randomly sampled clips instead of always using N uniformly sampled clips.
    """
    def __init__(self, datalist, tokenizer, img_lmdb_dir,
                 fps=3, num_frm=3, frm_sampling_strategy="rand",
                 max_img_size=1000, max_txt_len=40, itm_neg_size=1,
                 ensemble_n_clips=1, random_sample_clips=True,
                 video_fmt='.mp4', img_db_type='lmdb', is_train=False):
        super(AlproVideoRetrievalDataset, self).__init__(
            datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type,
            fps=fps, num_frm=num_frm,
            frm_sampling_strategy=frm_sampling_strategy,
            max_img_size=max_img_size, max_txt_len=max_txt_len)
        self.ensemble_n_clips = ensemble_n_clips
        self.num_labels = 2
        self.itm_neg_size = itm_neg_size
        self.random_sample_clips = random_sample_clips
        self.id2data = {
            d["id"]: d for group in datalist for d in group[1]}

        self.is_train = is_train
        self.video_fmt = video_fmt

        if self.is_train:
            self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])     
        else:
            self.randaug = None

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, index):
        # skip error videos:
        num_retries = 5
        for _ in range(num_retries):
            vid_id, examples = self.datalist[index]  # one video with multiple examples
            if self.ensemble_n_clips > 1:
                raise NotImplementedError('Do not support multiple clips for now.')
            else:
                video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) 
                vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)

            # Sample a random replacement if the current video could not be loaded.
            if vid_frm_array is None:
                LOGGER.info(f"Failed to load examples with video: {vid_id}. "
                            f"Will randomly sample an example as a replacement.")
                index = random.randint(0, len(self) - 1)
                continue
            sampled_examples = []
            for e in examples:
                s = self._get_single_example(e, index)
                if isinstance(s, dict):
                    sampled_examples.append(s)
                else:
                    sampled_examples.extend(s)
            return dict(
                vid=vid_frm_array,
                examples=sampled_examples,
                n_examples=len(sampled_examples)  # used to create image feature copies.
            )
        else:
            raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")

    def _get_single_example(self, data, index):
        examples = []

        text_str = data["txt"]
        itm_label = 1  # positive pair
        examples.append(dict(
            text_str=text_str,
            itm_label=itm_label
        ))
        return examples


class VideoRetrievalCollator(object):
    def __init__(self, tokenizer, max_length=40):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def collate_batch(self, batch):
        # FIXME there is a chance that two captions associated with the same video are batched together. Might need to fix.
        v_collate = default_collate
        visual_inputs = v_collate([d["vid"] for d in batch])  # (B, T, 3, H, W)
        # group data
        text_examples = flat_list_of_lists([d["examples"] for d in batch])
        n_examples_list = [d["n_examples"] for d in batch]  # (B, )
        # group elements data
        # one caption string per text example
        text_str_list = [d["text_str"] for d in text_examples]  # (B, )
        batch_enc = self.tokenizer.batch_encode_plus(
            text_str_list,
            max_length=self.max_length,
            padding='max_length',
            return_tensors="pt",
            truncation=True
        )
        text_input_ids = batch_enc.input_ids  # (B, L)
        text_input_mask = batch_enc.attention_mask  # (B, L)

        if "itm_label" in text_examples[0]:
            itm_labels = default_collate(
                [d["itm_label"] for d in text_examples])  # (B, )
        else:
            itm_labels = None

        if "id" in text_examples[0]:
            caption_ids = [d["id"] for d in text_examples]  # (B, )
        else:
            caption_ids = None
        collated_batch = dict(
            visual_inputs=visual_inputs,  # (B, #frm, C, H, W)
            text_input_ids=text_input_ids,
            text_input_mask=text_input_mask,
            caption_ids=caption_ids,  # list(int), example ids,
            labels=itm_labels,
            n_examples_list=n_examples_list  # used to create image feature copies.
        )
        if "vid_id" in batch[0] and len(batch) == 1:
            collated_batch["vid_id"] = batch[0]["vid_id"]
        return collated_batch


class AlproVideoRetrievalEvalDataset(AlproBaseDataset):
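    """ Sample by video/image: compute scores between each video and all the texts,
    looping through all the videos. Each batch contains a single video
    but multiple texts.

    datalist: list(dict), each dict is one caption with keys "id" (caption id,
        equal to its index in datalist), "vid_id" and "txt".
    tokenizer: text tokenizer, passed to AlproBaseDataset.
    max_img_size: int, frame height and width passed to the video loader.
    max_txt_len: int, max text sequence length, including special tokens.
    """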
    def __init__(self, datalist, tokenizer, img_lmdb_dir,
                 fps=3, num_frm=3, frm_sampling_strategy="rand",
                 max_img_size=1000, max_txt_len=40, ensemble_n_clips=1,
                 video_fmt='.mp4', img_db_type='lmdb'):
        self.ensemble_n_clips = ensemble_n_clips
        super(AlproVideoRetrievalEvalDataset, self).__init__(
            datalist, tokenizer, img_lmdb_dir,
            fps=fps, num_frm=num_frm,
            frm_sampling_strategy=frm_sampling_strategy,
            max_img_size=max_img_size, max_txt_len=max_txt_len,
            img_db_type=img_db_type)
        # id is unique id per caption/example
        for i, d in enumerate(self.datalist):
            assert i == d["id"]
        self.gt_cap_id2vid_id = {d["id"]: d["vid_id"] for d in datalist}
        self.cap_id2data = {d["id"]: d for d in datalist}
        self.batches, self.text_batch = self._prepare_batches_by_video()
        self.id2data = {d["id"]: d for d in self.datalist}

        self.video_fmt = video_fmt

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, index):
        # Each item pairs one video with all captions (the shared text_batch).
        batch = dict()

        batch["vid_id"] = self.batches[index]["vid_id"]  # one video with multiple examples
        batch["examples"] = self.text_batch["examples"]
        batch["n_examples"] = self.text_batch["n_examples"]
        batch["ids"] = self.text_batch["ids"]

        if self.ensemble_n_clips > 1:
            raise NotImplementedError('Do not support multiple clips for now.')
        else:
            vid_id = batch["vid_id"]

            video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) 
            vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)

        batch["vid"] = vid_frm_array
        return batch

    def _prepare_batches_by_video(self):
        """create batches where each batch contains a single video with multiple text"""
        text_list = []
        for d in self.datalist:
            text_list.append(dict(
                text_str=d["txt"],
                id=d["id"],
            ))
        text_batch = dict(
            vid_id=None,
            examples=text_list,
            n_examples=len(text_list),
            ids=[d["id"] for d in text_list]
        )

        # Make one batch per datalist entry (caption); each batch pairs that entry's
        # video with all the texts in text_batch, so a video with several captions
        # appears once per caption.
        batches = []
        for d in self.datalist:
            _batch = dict()
            _batch["vid_id"] = d["vid_id"]
            batches.append(_batch)
        return batches, text_batch
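

# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# AlproVideoRetrievalEvalDataset yields one video against all captions per item,
# so stacking the per-item score rows gives a video-by-text score matrix. Assuming
# the rows have been deduplicated to one per unique video and caption i's
# ground-truth video is row gt_video_idx[i], text-to-video Recall@K is the
# fraction of captions whose ground-truth video ranks in the top K.
import numpy as np


def _example_t2v_recall_at_k(scores, gt_video_idx, k):
    """scores: (num_videos, num_texts) similarities; gt_video_idx: (num_texts,) ints."""
    scores = np.asarray(scores)
    gt_video_idx = np.asarray(gt_video_idx)
    gt_scores = scores[gt_video_idx, np.arange(scores.shape[1])]
    # rank = number of videos scored strictly higher than the ground-truth video
    ranks = (scores > gt_scores[None, :]).sum(axis=0)
    return float(np.mean(ranks < k))


# Usage: _example_t2v_recall_at_k(score_matrix, [0, 1, 2], k=1)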


================================================
FILE: src/datasets/randaugment.py
================================================
import cv2
import numpy as np
import torch


## aug functions
def identity_func(img):
    return img


def autocontrast_func(img, cutoff=0):
    '''
        same output as PIL.ImageOps.autocontrast
    '''
    n_bins = 256

    def tune_channel(ch):
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            table = np.arange(n_bins)
        else:
            scale = (n_bins - 1) / (high - low)
            offset = -low * scale
            table = np.arange(n_bins) * scale + offset
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def equalize_func(img):
    '''
        same output as PIL.ImageOps.equalize
        PIL's implementation is different from cv2.equalize
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0: return ch
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out


def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
    like PIL, rotate by degree, not radians
    '''
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    M = cv2.getRotationMatrix2D(center, degree, 1)
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out


def horizontal_flip_func(img):
    '''
    [dxli]
    horizontally flip an image.
    '''
    out = cv2.flip(img, 1)

    return out


def solarize_func(img, thresh=128):
    '''
        same output as PIL.ImageOps.solarize
    '''
    table = np.array([el if el < thresh else 255 - el for el in range(256)])
    table = table.clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def color_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Color
    '''
    ## implementation according to PIL definition, quite slow
    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
    #  out = blend(degenerate, img, factor)
    #  M = (
    #      np.eye(3) * factor
    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
    #  )[np.newaxis, np.newaxis, :]
    M = (
            np.float32([
                [0.886, -0.114, -0.114],
                [-0.587, 0.413, -0.587],
                [-0.299, -0.299, 0.701]]) * factor
            + np.float32([[0.114], [0.587], [0.299]])
    )
    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
    return out


def contrast_func(img, factor):
    """
        same output as PIL.ImageEnhance.Contrast
    """
    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
    table = np.array([(
        el - mean) * factor + mean
        for el in range(256)
    ]).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def brightness_func(img, factor):
    '''
        same output as PIL.ImageEnhance.Brightness
    '''
    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out


def sharpness_func(img, factor):
    '''
    The differences between this result and PIL's are all on the four image
    borders; the central area is the same.
    '''
    '''
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
        out = out.clip(0, 255).astype(np.uint8)  # clip first to avoid uint8 wraparound
    return out


def shear_x_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, factor, 0], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, -offset], [0, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
        same output as PIL.Image.transform
    '''
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [0, 1, -offset]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


def posterize_func(img, bits):
    '''
        same output as PIL.ImageOps.posterize
    '''
    out = np.bitwise_and(img, np.uint8((255 << (8 - bits)) & 0xFF))  # keep the mask within uint8 range
    return out
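

# Illustrative sketch (hypothetical helper, not part of the original ALPRO code):
# posterize keeps the top `bits` bits of each 8-bit value by AND-ing it with a mask
# such as 0b11110000 (bits=4). The shifted value 255 << (8 - bits) exceeds 255, so
# it is masked with 0xFF before conversion to uint8; recent NumPy versions raise
# an error when converting an out-of-range Python int to uint8.
import numpy as np


def _example_posterize(values, bits):
    mask = np.uint8((0xFF << (8 - bits)) & 0xFF)
    return np.bitwise_and(np.asarray(values, dtype=np.uint8), mask)


# Usage: with bits=4, 200 (0b11001000) -> 192 (0b11000000) and 255 -> 240.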


def shear_y_func(img, factor, fill=(0, 0, 0)):
    H, W = img.shape[0], img.shape[1]
    M = np.float32([[1, 0, 0], [factor, 1, 0]])
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
    return out


# def cutout_func(img, pad_size, replace=(0, 0, 0)):
#     replace = np.array(replace, dtype=np.uint8)
#     H, W = img.shape[0], img.shape[1]
#     rh, rw = np.random.random(2)
#     pad_size = pad_size // 2
#     ch, cw = int(rh * H), int(rw * W)
#     x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
#     y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
#     out = img.copy()
#     out[x1:x2, y1:y2, :] = replace
#     return out


### level to args
def enhance_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
    return level_to_args


def shear_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 0.3
        # if np.random.random() > 0.5: level = -level
        return (level, replace_value)

    return level_to_args


def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * float(translate_const)
        # if np.random.random() > 0.5: level = -level
        return (level, replace_value)

    return level_to_args


def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * cutout_const)
        return (level, replace_value)

    return level_to_args


def solarize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 256)
        return (level, )
    return level_to_args


def none_level_to_args(level):
    return ()


def posterize_level_to_args(MAX_LEVEL):
    def level_to_args(level):
        level = int((level / MAX_LEVEL) * 4)
        return (level, )
    return level_to_args


def rotate_level_to_args(MAX_LEVEL, replace_value):
    def level_to_args(level):
        level = (level / MAX_LEVEL) * 30
        # if np.random.random() < 0.5:
        #     level = -level
        return (level, replace_value)

    return level_to_args


func_dict = {
    'Identity': identity_func,
    # 'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
    'HorizontalFlip': horizontal_flip_func # [dxli]
}

translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
    'Identity': none_level_to_args,
    # 'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
    'HorizontalFlip': none_level_to_args  # [dxli]
}


class TemporalConsistentRandomAugment(object):

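    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=None):
        self.N = N  # number of ops sampled per clip
        self.M = M  # magnitude level shared by all ops, in [0, MAX_LEVEL]
        self.p = p  # each sampled op is skipped with probability p
        self.tensor_in_tensor_out = tensor_in_tensor_out
        # avoid a mutable default argument; fall back to all registered ops
        self.augs = list(augs) if augs else list(arg_dict.keys())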

    def get_random_ops(self):
        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
        # return [(op, 0.5, self.M) for op in sampled_ops]
        return [(op, self.M) for op in sampled_ops]

    def __call__(self, frames):
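        assert frames.shape[-1] == 3, 'Expecting the last dimension to hold 3 RGB channels, i.e. (t, h, w, c).'

        if self.tensor_in_tensor_out:
            frames = frames.numpy().astype(np.uint8)

        num_frames = frames.shape[0]

        # sample the ops and their on/off switches once, then share them across all
        # frames so that every frame of the clip receives the same augmentation
        ops = num_frames * [self.get_random_ops()]
        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]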

        frames = torch.stack(list(map(self._aug, frames, ops, apply_or_not)), dim=0).float()

        return frames

    def _aug(self, img, ops, apply_or_not):
        for i, (name, level) in enumerate(ops):
            if not apply_or_not[i]:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args) 
        return torch.from_numpy(img)
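
# Illustrative sketch (hypothetical helper, NumPy-only, not used elsewhere):
# the idea behind TemporalConsistentRandomAugment is to sample the op names
# and their on/off switches once per clip and reuse them for every frame, so
# all frames of a video receive the same transformation. `ops` is an assumed
# dict of callables that take and return one (h, w, c) frame; unlike the
# class above, no level/args lookup is involved.
import numpy as np


def _sketch_temporally_consistent_aug(frames, ops, n_ops, p, rng):
    # one draw per clip, shared by all frames
    names = rng.choice(sorted(ops), size=n_ops, replace=False)
    apply_mask = rng.random(n_ops) > p  # an op is applied with probability 1 - p
    out = []
    for frame in frames:
        for name, apply_op in zip(names, apply_mask):
            if apply_op:
                frame = ops[name](frame)
        out.append(frame)
    return np.stack(out, axis=0)

# usage: _sketch_temporally_consistent_aug(frames, {'flip': lambda f: f[:, ::-1]}, 1, 0.0, np.random.default_rng(0))
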

class RandomAugment(object):

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        self.N = N
        self.M = M
        self.isPIL = isPIL
        # avoid a mutable default argument; fall back to all registered ops
        self.augs = list(augs) if augs else list(arg_dict.keys())

    def get_random_ops(self):
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)            
        ops = self.get_random_ops()
        for name, prob, level in ops:
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args) 
        return img


def save_frames_grid(img_array, out_path):
    import torch
    from torchvision.utils import make_grid
    from PIL import Image

    if len(img_array.shape) == 3:
        img_array = img_array.unsqueeze(0)
    elif len(img_array.shape) == 5:
        # flatten (b, t, c, h, w) into (b * t, c, h, w); reshape also handles non-contiguous inputs
        b, t, c, h, w = img_array.shape
        img_array = img_array.reshape(-1, c, h, w)
    elif len(img_array.shape) == 4:
        pass
    else:
        raise NotImplementedError('Supports only (b, t, c, h, w)-shaped inputs; the first two dimensions are optional.')

    assert img_array.shape[1] == 3, "Expecting input shape of (N, 3, H, W), i.e. RGB-only."
    
    grid = make_grid(img_array)
    ndarr = grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()

    img = Image.fromarray(ndarr)

    img.save(out_path)


def stack(data, dim=0):
    """Stack a non-empty list of same-shaped tensors along a new dimension `dim`.

    Equivalent to torch.stack: every tensor gets a new size-1 dimension at `dim`
    before concatenation, which also handles negative dims such as dim=-1.
    """
    if len(data) == 0:
        raise ValueError('stack expects a non-empty list of tensors.')
    return torch.cat([d.unsqueeze(dim) for d in data], dim=dim)
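
# Illustrative sketch (NumPy arrays stand in for torch tensors; hypothetical
# helper): stacking can be written as "add a new length-1 axis to every input,
# then concatenate along that axis". Expanding each input first, instead of
# concatenating and reshaping afterwards, matches np.stack for every valid
# axis, including negative ones such as -1 where a concatenate-then-reshape
# approach would insert the new axis one position too early.
import numpy as np


def _sketch_stack_via_concat(arrays, axis=0):
    if len(arrays) == 0:
        raise ValueError('expected a non-empty sequence of arrays')
    return np.concatenate([np.expand_dims(a, axis) for a in arrays], axis=axis)
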


if __name__ == '__main__':
    import decord, os
    from decord import VideoReader
    decord.bridge.set_bridge('torch')

    root_dir = '/export/share/dongxuli/data/webvid2m/postprocess/downsampled_videos'
    video_id = '1058234725.mp4'

    video_path = os.path.join(root_dir, video_id) 
    vr = VideoReader(video_path)

    frames = vr.get_batch([1, 3, 5, 7, 9])

    # a = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast', 'Equalize','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'])     
    a = TemporalConsistentRandomAugment(N=1, M=5, augs=['HorizontalFlip'])

    print(frames[0].shape)
    save_frames_grid(frames.permute(0, 3, 1, 2), 'before.jpg')

    after_frames = a(frames)
    print(after_frames.shape)

    save_frames_grid(after_frames.permute(0, 3, 1, 2), 'after.jpg')

================================================
FILE: src/modeling/alpro_models.py
================================================
import copy

import numpy as np
import torch
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from einops import rearrange, reduce, repeat
from horovod import torch as hvd
from src.modeling.timesformer.vit import TimeSformer
from src.modeling.xbert import (BertEmbeddings, BertEncoder, BertForMaskedLM,
                                BertLMPredictionHead, BertModel, BertPooler,
                                BertPreTrainedModel, BertPreTrainingHeads)
from src.utils.basic_utils import load_json, load_jsonl, save_frames_grid
from src.utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss


class AlproBaseModel(nn.Module):
    def __init__(self, config=None, input_format='RGB', video_enc_cfg=None, temp=0.07):
        super().__init__()
        
        self.temp = nn.Parameter(torch.ones([]) * temp)   

        self.bert_config = config

        # resolve the visual encoder class from its name in the config, e.g. TimeSformer (imported above)
        visual_model_cls = eval(video_enc_cfg['cls'])

        self.visual_encoder = visual_model_cls(model_cfg=video_enc_cfg, input_format=input_format, cross_attention_config=config)
        self.text_encoder = BertForMaskedLM.from_pretrained('bert-base-uncased', config=self.bert_config)

        # FIXME make them configurable
        embed_dim = 256
        vision_width = 768

        text_width = self.bert_config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)         

        self.itc_token_type = self.bert_config.itc_token_type
        self.itm_head = nn.Linear(text_width, 2)     


    def load_separate_ckpt(self, visual_weights_path=None, bert_weights_path=None):
        if visual_weights_path:
            self.visual_encoder.load_state_dict(visual_weights_path)

        # if bert_weights_path:
        #     load_multimodal_encoder_state_dict_with_mismatch(self.cross_encoder, bert_weights_path)
        #     load_mlm_head_state_dict_with_mismatch(self.mlm_head, bert_weights_path)

    # def freeze_cnn_backbone(self):
    #     for n, p in self.visual_encoder.feature.named_parameters():
    #         p.requires_grad = False


class AlproForPretrain(AlproBaseModel):
    def __init__(self, config, video_enc_cfg, input_format='RGB'):
        super(AlproForPretrain, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)

        # model for generating pseudo labels
        self.prompter = Prompter(config, video_enc_cfg)

        self.use_mask_prob = 0
        self.mpm_head = nn.Sequential(
            nn.Linear(config.hidden_size,
                      config.hidden_size * 2),
            nn.ReLU(True),
            nn.Linear(config.hidden_size * 2, self.prompter.entity_num)
        )

    def build_text_prompts(self, prompts):
        self.prompter.build_text_prompts(prompts)

    def get_pseudo_labels(self, batch):
        return self.prompter.get_pseudo_labels(batch)

    def forward(self, batch):
        with torch.no_grad():
            self.temp.clamp_(0.001,0.5)

        visual_inputs = batch['visual_inputs']

        use_mpm = 'mpm_mask' in batch
        if use_mpm:
            context_visual_inputs = batch['context_visual_inputs']

        device = visual_inputs.device
        b, t, c, h, w = visual_inputs.shape

        # forward image and text features
        # feats are normalized embeds
        if use_mpm and np.random.uniform() < self.use_mask_prob:
            video_embeds_total = self._forward_visual_embeds(torch.cat([visual_inputs, context_visual_inputs], dim=0))
            # split for unmasked and masked
            video_embeds, context_video_embeds = video_embeds_total[:b], video_embeds_total[b:]
        else:
            video_embeds = self._forward_visual_embeds(visual_inputs)
            context_video_embeds = video_embeds

        # we compute normalized feats for unmasked visual inputs only, used for ITC
        video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  
        video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)
        
        # text embeddings and features
        text_embeds, text_feat = self._forward_text_feats(batch)

        # ========== (in-batch) ITC loss ==========
        gathered_video_feats = hvd.allgather(video_feat)
        gathered_text_feats = hvd.allgather(text_feat)

        assert self.itc_token_type == 'cls', 'Only CLS tokens are supported for ITC, found {}.'.format(self.itc_token_type)
        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp 
        sim_t2v = text_feat @ gathered_video_feats.t() / self.temp 
                             
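        # [IMPORTANT] be very careful when initializing the GT sim_v2t
        # hvd.allgather concatenates the features in the order of hvd.rank() (global rank),
        # so the positives of this worker are columns [b * rank, b * (rank + 1)).
        # This assumes that every worker has the same local batch size b.
        sim_targets = torch.zeros_like(sim_v2t)

        rank = hvd.rank()
        b_start, b_end = b * rank, b * (rank + 1)
        sim_targets[:, b_start: b_end] = torch.eye(b, device=sim_targets.device)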

        loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1)*sim_targets,dim=1).mean()
        loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1)*sim_targets,dim=1).mean() 

        vtc_loss = (loss_v2t+loss_t2v) / 2

        # ========= VTM ==========
        text_atts = batch['text_input_mask']

        # non-masked text and non-masked image 
        vtm_loss, vtm_logits, vtm_labels, encoder_outputs_pos = self.compute_vtm(text_embeds=text_embeds, 
                                                                                 text_atts=text_atts, 
                                                                                 video_embeds=video_embeds, 
                                                                                 video_atts=video_atts, 
                                                                                 sim_v2t=sim_v2t.clone(), # for hard mining
                                                                                 sim_t2v=sim_t2v.clone(), # for hard mining
                                                                                 return_encoder_out=True
                                                                                )

        # ========= MLM ========== 
        # masked text and non-masked image
        if 'mlm_labels' in batch: 
            mlm_labels = batch['mlm_labels']
            mlm_text_input_ids = batch['mlm_text_input_ids']

            mlm_loss, mlm_logits, mlm_labels = self.compute_mlm(input_ids=mlm_text_input_ids,
                                                                text_input_mask=text_atts,
                                                                video_embeds=video_embeds, 
                                                                video_atts=video_atts,
                                                                mlm_labels=mlm_labels
                                                                )
        else:
            mlm_logits = mlm_loss = mlm_labels = None

        # ========= MPM ========== 
        if use_mpm: 
            mpm_labels, ignore_masks = self.get_pseudo_labels(batch)

            mpm_loss, mpm_logits = self.compute_mpm_with_encoder_out(encoder_outputs=encoder_outputs_pos, 
                                                                     text_atts=text_atts, 
                                                                     soft_labels=mpm_labels, 
                                                                     ignore_masks=ignore_masks, 
                                                                     patch_masks=batch['mpm_mask']
                                                                    )

        else:
            mpm_loss = mpm_logits = mpm_labels =  None

        return dict(
            itc_loss=vtc_loss,
            mlm_scores=mlm_logits,  # (B, Lt, vocab_size), only text part
            mlm_loss=mlm_loss,  # scalar, mean over positions whose label is not -100
            mlm_labels=mlm_labels,  # (B, Lt), with -100 indicating ignored positions
            itm_scores=vtm_logits,  # (3B, 2): B positive pairs followed by 2B negative pairs
            itm_loss=vtm_loss,  # scalar
            itm_labels=vtm_labels,  # (3B, )
            mpm_loss=mpm_loss,
            mpm_logits=mpm_logits,
            mpm_labels=mpm_labels
        )


    def _forward_visual_embeds(self, visual_inputs):
        b, t, c, h, w = visual_inputs.shape
        # TimeSformer expects (b, c, t, h, w) as input.
        # image features
        visual_inputs = visual_inputs.transpose(1, 2)

        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)

        return video_embeds

    def _forward_text_feats(self, batch):
        # text features
        text_output = self.text_encoder.bert(batch['text_input_ids'], 
                                             attention_mask=batch['text_input_mask'],                      
                                             return_dict = True, 
                                             mode = 'text'
                                            )

        text_embeds = text_output.last_hidden_state # b, Lt, fsz=768
        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 

        return text_embeds, text_feat

    def compute_mpm_with_encoder_out(self, encoder_outputs, text_atts, soft_labels, ignore_masks, patch_masks):
        txt_len = text_atts.shape[1]
        # skip the txt_len text tokens plus one to drop the visual CLS token
        visual_output = encoder_outputs.last_hidden_state[:, txt_len+1:]

        bsz, h, w = patch_masks.shape
        patch_masks_flatten_inverted = (1 - patch_masks.view(bsz, -1)).unsqueeze(-1)

        # mean embeds of masked visual regions
        num_masked_patches = torch.sum(patch_masks_flatten_inverted.squeeze(-1), dim=-1, keepdim=True)

        masked_visual_embeds = patch_masks_flatten_inverted * visual_output
        masked_visual_embeds = torch.sum(masked_visual_embeds, dim=1)
        masked_visual_embeds /= num_masked_patches

        # loss
        mpm_logits = self.mpm_head(masked_visual_embeds)

        cross_entropy = -torch.sum(F.log_softmax(mpm_logits, dim=1) * soft_labels, dim=1)
        cross_entropy[ignore_masks] = 0.

        mpm_loss = torch.sum(cross_entropy) / (bsz - torch.sum(ignore_masks))

        return mpm_loss, mpm_logits 

    def compute_mpm(self, text_embeds, text_atts, image_embeds, image_atts, soft_labels, ignore_masks, patch_masks):
        # forward cross-encoder
        attention_mask = torch.cat([text_atts, image_atts], dim=1)
        embedding_output = torch.cat([text_embeds, image_embeds], dim=1)

        encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,
                                                 attention_mask=attention_mask,
                                                 return_dict=True,
                                                 mode='fusion'
                                                )

        txt_len = text_atts.shape[1]
        # skip the txt_len text tokens plus one to drop the visual CLS token
        visual_output = encoder_outputs.last_hidden_state[:, txt_len+1:]

        bsz, h, w = patch_masks.shape
        patch_masks_flatten_inverted = (1 - patch_masks.view(bsz, -1)).unsqueeze(-1)

        # mean embeds of masked visual regions
        num_masked_patches = torch.sum(patch_masks_flatten_inverted.squeeze(-1), dim=-1, keepdim=True)

        masked_visual_embeds = patch_masks_flatten_inverted * visual_output
        masked_visual_embeds = torch.sum(masked_visual_embeds, dim=1)
        masked_visual_embeds /= num_masked_patches

        # loss
        mpm_logits = self.mpm_head(masked_visual_embeds)

        cross_entropy = -torch.sum(F.log_softmax(mpm_logits, dim=1) * soft_labels, dim=1)
        cross_entropy[ignore_masks] = 0.

        mpm_loss = torch.sum(cross_entropy) / (bsz - torch.sum(ignore_masks))

        return mpm_loss, mpm_logits 

    def compute_vtm(self, text_embeds, text_atts, video_embeds, video_atts, sim_v2t, sim_t2v, return_encoder_out=False):
        device = text_embeds.device

        # ====== positive pairs =======
        attention_mask = torch.cat([text_atts, video_atts], dim=1)
        embedding_output_pos = torch.cat([text_embeds, video_embeds], dim=1)

        encoder_outputs_pos = self.text_encoder.bert(encoder_embeds=embedding_output_pos,
                                                     attention_mask=attention_mask,
                                                     return_dict=True,
                                                     mode='fusion'
                                                    )

        # ====== negative pairs =======
        bs = text_embeds.shape[0] 

        # hvd.allgather orders the gathered features by global rank, not local rank
        rank = hvd.rank()
        b_start, b_end = bs * rank, bs * (rank + 1)

        with torch.no_grad():
            weights_i2t = sim_v2t[:,b_start:b_end]
            weights_t2i = sim_t2v[:,b_start:b_end]
   
            # never select self as negative
            weights_i2t.fill_diagonal_(float('-inf'))
            weights_t2i.fill_diagonal_(float('-inf'))

            weights_i2t = F.softmax(weights_i2t, dim=1)
            weights_t2i = F.softmax(weights_t2i, dim=1)

        # select a negative video for each text
        # FIXME to optimize using indexing operations
        video_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            video_embeds_neg.append(video_embeds[neg_idx])
        video_embeds_neg = torch.stack(video_embeds_neg,dim=0)   

        # select a negative text for each video
        text_embeds_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_embeds_neg.append(text_embeds[neg_idx])
            text_atts_neg.append(text_atts[neg_idx])

        text_embeds_neg = torch.stack(text_embeds_neg,dim=0)   
        text_atts_neg = torch.stack(text_atts_neg,dim=0)      

        text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)     
        text_atts_all = torch.cat([text_atts, text_atts_neg],dim=0)     

        video_embeds_all = torch.cat([video_embeds_neg,video_embeds],dim=0)
        video_atts_all = torch.cat([video_atts,video_atts],dim=0)

        attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)
        embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)

        # forward negative pairs via cross encoder
        encoder_outputs_neg = self.text_encoder.bert(encoder_embeds=embedding_output_all,
                                                     attention_mask=attention_mask_all,
                                                     return_dict=True,
                                                     mode='fusion'
                                                    )

        vl_embeddings = torch.cat([encoder_outputs_pos.last_hidden_state[:,0,:], 
                                   encoder_outputs_neg.last_hidden_state[:,0,:]],dim=0)
        vtm_logits = self.itm_head(vl_embeddings)            

        vtm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], dim=0).to(device)
        vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)     

        if return_encoder_out:
            return vtm_loss, vtm_logits, vtm_labels, encoder_outputs_pos 
        else:
            return vtm_loss, vtm_logits, vtm_labels, None
        
    def compute_mlm(self, input_ids, text_input_mask, video_embeds, video_atts, mlm_labels):
        # forward text features with masked_input_ids
        text_output = self.text_encoder.bert(input_ids,
                                             attention_mask=text_input_mask,
                                             return_dict=True,
                                             mode='text'
                                            )
        text_embeds = text_output.last_hidden_state

        # forward cross-encoder
        attention_mask = torch.cat([text_input_mask, video_atts], dim=1)
        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)

        encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,
                                                 attention_mask=attention_mask,
                                                 return_dict=True,
                                                 mode='fusion'
                                                )

        txt_len = text_input_mask.shape[1]
        txt_output = encoder_outputs.last_hidden_state[:, :txt_len]

        mlm_logits = self.text_encoder.cls(txt_output)

        loss_fct = CrossEntropyLoss()
        mlm_loss = loss_fct(mlm_logits.view(-1, self.bert_config.vocab_size), mlm_labels.view(-1))

        return mlm_loss, mlm_logits, mlm_labels
        
    def load_separate_ckpt(self, visual_weights_path=None, bert_weights_path=None, prompter_weights_path=None):
        if visual_weights_path:
            self.visual_encoder.load_state_dict(visual_weights_path)

        # [NOTE] BERT is initialized from huggingface pre-trained weights. 
        # if bert_weights_path:
        #     load_multimodal_encoder_state_dict_with_mismatch(self.cross_encoder, bert_weights_path)
        #     load_mlm_head_state_dict_with_mismatch(self.mlm_head, bert_weights_path)

        # TODO make path configurable
        if prompter_weights_path is not None:
            self.prompter.load_pretrained_weights_without_prompts(prompter_weights_path)
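
# Illustrative sketch (NumPy-only, hypothetical helper, not called by the
# model): the in-batch ITC loss computed in AlproForPretrain.forward. Each
# worker holds `b` L2-normalized video/text features; the gathered matrices
# are assumed to hold the features of all workers concatenated in rank order
# (Horovod's allgather concatenates along the first dimension in rank order),
# so the positives of worker `rank` sit in columns [b * rank, b * (rank + 1)).
# The loss averages the video-to-text and text-to-video cross-entropies.
import numpy as np


def _sketch_itc_loss(video_feat, text_feat, gathered_video, gathered_text, rank, temp):
    def log_softmax(x):
        x = x - x.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
        return x - np.log(np.exp(x).sum(axis=1, keepdims=True))

    b = video_feat.shape[0]
    sim_v2t = video_feat @ gathered_text.T / temp
    sim_t2v = text_feat @ gathered_video.T / temp
    targets = np.zeros_like(sim_v2t)
    targets[:, b * rank: b * (rank + 1)] = np.eye(b)
    loss_v2t = -(log_softmax(sim_v2t) * targets).sum(axis=1).mean()
    loss_t2v = -(log_softmax(sim_t2v) * targets).sum(axis=1).mean()
    return (loss_v2t + loss_t2v) / 2
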

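
# Illustrative sketch (NumPy-only, hypothetical helper): the hard-negative
# mining in AlproForPretrain.compute_vtm. The local (b x b) block of the
# similarity matrix is turned into a sampling distribution with the diagonal
# (the positive pair) set to -inf, so a sample's own partner is never drawn
# and more similar, i.e. harder, negatives are drawn more often. Requires b >= 2.
import numpy as np


def _sketch_sample_hard_negatives(sim_local, rng):
    logits = sim_local.astype(float)  # astype copies, so the caller's array is untouched
    np.fill_diagonal(logits, -np.inf)  # never select self as negative
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    # one negative index per row, drawn from that row's softmax distribution
    return np.array([rng.choice(len(row), p=row) for row in probs])
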

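
# Illustrative sketch (NumPy-only, hypothetical helper): the MPM loss in
# compute_mpm_with_encoder_out. Following the code above, patches whose
# patch_masks entry is 0 are mean-pooled into one embedding per sample, a
# classifier (here an assumed weight matrix `w` in place of mpm_head) scores
# it against the entities, and the soft cross-entropy with the prompter's
# pseudo labels is averaged over the samples that are not ignored.
import numpy as np


def _sketch_mpm_loss(visual_output, patch_masks, w, soft_labels, ignore_masks):
    bsz = patch_masks.shape[0]
    pool_w = (1 - patch_masks.reshape(bsz, -1))[..., None]  # (bsz, n_patches, 1)
    pooled = (pool_w * visual_output).sum(axis=1) / pool_w.sum(axis=1)  # (bsz, d)
    logits = pooled @ w  # (bsz, n_entities)
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    ce = -(log_probs * soft_labels).sum(axis=1)
    ce[ignore_masks] = 0.0  # ignored samples contribute nothing
    return ce.sum() / (bsz - ignore_masks.sum())
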
class Prompter(AlproBaseModel):
    def __init__(self, config, video_enc_cfg, input_format='RGB'):
        super(Prompter, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)

        # self.entity_num = 1000
        self.entity_num = config.num_entities

        self.register_buffer("video_prompt_feat", torch.rand(self.entity_num, 256)) 
        self.register_buffer("image_prompt_feat", torch.rand(self.entity_num, 256)) 

        self.prompt_initialized = False
        # if the prob for the most likely entity is < 0.2, we just ignore it
        self.ignore_threshold = 0.2


    def load_pretrained_weights_without_prompts(self, ckpt_path):
        LOGGER.info("Loading weights for teacher model.")
        loaded_state_dict = torch.load(ckpt_path, map_location='cpu')

        loaded_keys = loaded_state_dict.keys()
        model_keys = self.state_dict().keys()

        load_not_in_model = [k for k in loaded_keys if k not in model_keys]
        model_not_in_load = [k for k in model_keys if k not in loaded_keys]

        if hvd.rank() == 0:
            LOGGER.info("Keys in loaded but not in model:")
            LOGGER.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
            LOGGER.info("Keys in model but not in loaded:")
            LOGGER.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")

        # FIXME a quick hack to avoid loading prompts
        loaded_state_dict = {k: v for k, v in loaded_state_dict.items() if 'prompt_feat' not in k}

        self.load_state_dict(loaded_state_dict, strict=False)

    def build_text_prompts(self, prompts):
        """
        This function will be called, if no e2e.weights is provided.
        In that case, 
        """
        assert not self.prompt_initialized, "Repetitively building prompts?"

        if self.training:
            self.eval()

        video_prompt_feat_all = []
        image_prompt_feat_all = []

        with torch.no_grad():
            # this is configurable depending on the GPU memory limit
            step_size = 10000

            # ====== initializing video prompting ======
            b_video, _ = prompts['batch_enc_video_prompts'].input_ids.shape

            start = 0
            end = start + step_size

            while start < b_video:
                video_prompt_output = self.text_encoder.bert(prompts['batch_enc_video_prompts'].input_ids[start:end].cuda(), 
                                                            attention_mask=prompts['batch_enc_video_prompts'].attention_mask[start:end].cuda(),                      
                                                            return_dict=True, 
                                                            mode='text'
                                                            )

                video_prompt_embeds = video_prompt_output.last_hidden_state # b, Lt, fsz=768
                video_prompt_feat = F.normalize(self.text_proj(video_prompt_embeds[:,0,:]),dim=-1)                 

                # collecting
                video_prompt_feat_all.append(video_prompt_feat)
            
                start += step_size
                end += step_size

            # average ensembling over the prompt templates
            video_prompt_feat = torch.cat(video_prompt_feat_all, dim=0)
            video_num_templates = int(video_prompt_feat.shape[0] / self.entity_num)

            video_prompt_feat = torch.stack(video_prompt_feat.chunk(video_num_templates), dim=1)
            video_prompt_feat = torch.mean(video_prompt_feat, dim=1)
            self.video_prompt_feat = video_prompt_feat

            # ====== initializing image prompting ======
            b_image, _ = prompts['batch_enc_image_prompts'].input_ids.shape

            start = 0
            end = start + step_size

            while start < b_image:
                # image prompts
                image_prompt_output = self.text_encoder.bert(prompts['batch_enc_image_prompts'].input_ids[start:end].cuda(), 
                                                            attention_mask=prompts['batch_enc_image_prompts'].attention_mask[start:end].cuda(),                      
                                                            return_dict = True, 
                                                            mode = 'text'
                                                            )

                image_prompt_embeds = image_prompt_output.last_hidden_state # b, Lt, fsz=768
                image_prompt_feat = F.normalize(self.text_proj(image_prompt_embeds[:,0,:]),dim=-1)                 

                # collecting
                image_prompt_feat_all.append(image_prompt_feat)

                start += step_size
                end += step_size

            image_prompt_feat = torch.cat(image_prompt_feat_all, dim=0)
            image_num_templates = int(image_prompt_feat.shape[0] / self.entity_num)

            image_prompt_feat = torch.stack(image_prompt_feat.chunk(image_num_templates), dim=1)
            image_prompt_feat = torch.mean(image_prompt_feat, dim=1)
            self.image_prompt_feat = image_prompt_feat

        self.prompt_initialized = True

    def _forward_visual_embeds(self, visual_inputs):
        b, t, c, h, w = visual_inputs.shape
        # TimeSformer expects (b, c, t, h, w) as input.
        # image features
        visual_inputs = visual_inputs.transpose(1, 2)

        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)

        assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)
        if self.itc_token_type == 'cls':
            video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  
        else:
            raise NotImplementedError("itc_type_type must be one of ['mean', 'cls', 'mil'], found {}".format(self.itc_token_type))
        
        return video_embeds, video_feat

    def _compute_soft_labels(self, sim_vp_masked):
        soft_labels = nn.Softmax(dim=1)(sim_vp_masked)
        # ignore samples whose most likely entity has probability < ignore_threshold
        ignore_masks = torch.max(soft_labels, dim=1)[0] < self.ignore_threshold

        return soft_labels, ignore_masks

    def get_pseudo_labels(self, batch):
        if self.training:
            self.eval()

        with torch.no_grad():
            masked_visual_inputs = batch['crop_visual_inputs']

            _, masked_image_feat = self._forward_visual_embeds(masked_visual_inputs)

            if batch['type'] == 'video':
                prompt_feat = self.video_prompt_feat
            else:
                prompt_feat = self.image_prompt_feat

            # similarity between the cropped visual features and the video (or image) prompt features
            sim_masked = masked_image_feat @ prompt_feat.t() / self.temp 

            pseudo_labels, ignore_masks = self._compute_soft_labels(sim_masked)

        return pseudo_labels, ignore_masks

    def forward(self, batch):
        visual_inputs = batch['visual_inputs']

        device = visual_inputs.device
        b, t, c, h, w = visual_inputs.shape

        # forward image and text features
        # feats are normalized embeds
        video_embeds, video_feat, text_embeds, text_feat = self.forward_feats(batch)
        image_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)

        # ========== (in-batch) ITC loss ==========
        gathered_image_feats = hvd.allgather(video_feat)
        gathered_text_feats = hvd.allgather(text_feat)

        assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)

        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp 
        sim_t2v = text_feat @ gathered_image_feats.t() / self.temp 
                             
        # [IMPORTANT] be very careful when initializing the GT sim_v2t
        # hvd.allgather concatenates the features in the order of hvd.rank() (global rank)
        sim_targets = torch.zeros_like(sim_v2t)

        rank = hvd.rank()
        b_start, b_end = b * rank, b * (rank + 1)
        sim_targets[:, b_start: b_end] = torch.eye(b, device=sim_targets.device)

        sim_v2t_scores = F.log_softmax(sim_v2t, dim=1)
        sim_t2v_scores = F.log_softmax(sim_t2v, dim=1)

        loss_v2t = -torch.sum(sim_v2t_scores * sim_targets,dim=1).mean()
        loss_t2v = -torch.sum(sim_t2v_scores * sim_targets,dim=1).mean() 

        vtc_loss = (loss_v2t+loss_t2v) / 2

        return dict(
            itc_loss=vtc_loss,
            itc_labels=torch.max(sim_targets, dim=1)[1],
            i2t_scores=sim_v2t_scores,
            t2i_scores=sim_t2v_scores
        )


    def forward_feats(self, batch):
        with torch.no_grad():
            self.temp.clamp_(0.001,0.5)

        visual_inputs = batch['visual_inputs']

        b, t, c, h, w = visual_inputs.shape
        # TimeSformer expects (b, c, t, h, w) as input.
        # image features
        visual_inputs = visual_inputs.transpose(1, 2)

        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)

        assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)
        if self.itc_token_type == 'cls':
            video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  
        else:
            raise NotImplementedError("itc_type_type must be one of ['mean', 'cls', 'mil'], found {}".format(self.itc_token_type))

        # text features
        text_output = self.text_encoder.bert(batch['text_input_ids'], 
                                             attention_mask=batch['text_input_mask'],                      
                                             return_dict = True, 
                                             mode = 'text'
                                            )

        text_embeds = text_output.last_hidden_state # b, Lt, fsz=768

        if self.itc_token_type == 'cls':
            text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 
        else:
            raise NotImplementedError("itc_token_type must be 'cls', found {}".format(self.itc_token_type))

        return video_embeds, video_feat, text_embeds, text_feat


class AlproForSequenceClassification(AlproBaseModel):
    def __init__(self, config, video_enc_cfg, input_format='RGB'):
        super(AlproForSequenceClassification, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)

        # A plain BertModel (no MLM head), so it is called directly rather than via `.bert`.
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased', config=self.bert_config, add_pooling_layer=False)

        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size,
                      config.hidden_size * 2),
            nn.ReLU(True),
            nn.Linear(config.hidden_size * 2, config.num_labels)
        )

    # def forward(self, image, text, targets, alpha=0, train=True):
    def forward(self, batch):
        visual_inputs = batch['visual_inputs']
        targets = batch.get('labels')  # None at evaluation time

        device = visual_inputs.device

        # forward text
        text_input_mask = batch['text_input_mask']
        text_output = self.text_encoder(batch['text_input_ids'],
                                        attention_mask=text_input_mask,
                                        return_dict=True,
                                        mode='text'
                                        )
        text_embeds = text_output.last_hidden_state

        # forward visual
        b, t, c, h, w = visual_inputs.shape
        # timeSformer asks for (b, c, t, h, w) as input.
        visual_inputs = visual_inputs.transpose(1, 2)

        image_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(device)

        # forward cross-encoder
        attention_mask = torch.cat([text_input_mask, image_atts], dim=1)
        embedding_output = torch.cat([text_embeds, image_embeds], dim=1)

        output = self.text_encoder(encoder_embeds=embedding_output,
                                attention_mask=attention_mask,
                                return_dict=True,
                                mode='fusion'
                                )

        prediction = self.classifier(output.last_hidden_state[:,0,:])                
        if targets is not None:
            loss = F.cross_entropy(prediction, targets)                
        else: # evaluation mode
            loss = 0

        return dict(loss=loss,
                    logits=prediction
                    )
            

    def forward_inference(self, batch):
        visual_inputs = batch['visual_inputs']
        device = visual_inputs.device

        # forward text
        text_input_mask = batch['text_input_mask']
        text_output = self.text_encoder(batch['text_input_ids'],
                                        attention_mask=text_input_mask,
                                        return_dict=True,
                                        mode='text'
                                        )
        text_embeds = text_output.last_hidden_state

        # forward visual
        b, t, c, h, w = visual_inputs.shape
        # timeSformer asks for (b, c, t, h, w) as input.
        visual_inputs = visual_inputs.transpose(1, 2)

        image_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(device)

        # forward cross-encoder
        attention_mask = torch.cat([text_input_mask, image_atts], dim=1)
        embedding_output = torch.cat([text_embeds, image_embeds], dim=1)

        output = self.text_encoder(encoder_embeds=embedding_output,
                                   attention_mask=attention_mask,
                                   return_dict=True,
                                   mode='fusion'
                                   )

        prediction = self.classifier(output.last_hidden_state[:,0,:])                

        return prediction


class AlproForVideoTextRetrieval(AlproBaseModel):
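    """ALPRO fine-tuned for video-text retrieval, trained with a video-text
    contrastive (ITC) loss and a video-text matching (ITM) loss on hard negatives.
    """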
    def __init__(self, config, video_enc_cfg, input_format='RGB'):
        super(AlproForVideoTextRetrieval, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)

    def forward(self, batch):
        with torch.no_grad():
            self.temp.clamp_(0.001,0.5)

        visual_inputs = batch['visual_inputs']
        text_input_mask = batch['text_input_mask']
        text_input_ids = batch['text_input_ids']

        device = visual_inputs.device

        b, t, c, h, w = visual_inputs.shape
        # timeSformer asks for (b, c, t, h, w) as input.
        # visual embeddings
        visual_inputs = visual_inputs.transpose(1, 2)

        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
        # image_embeds = image_embeds.repeat(text_input_mask.shape[0], 1, 1)
        video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  

        video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)

        # text embeddings
        text_output = self.text_encoder.bert(text_input_ids,
                                             attention_mask=text_input_mask,
                                             return_dict=True,
                                             mode='text'
                                            )
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 

        # ========== (in-batch) ITC loss ==========
        gathered_video_feats = hvd.allgather(video_feat)
        gathered_text_feats = hvd.allgather(text_feat)

        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp 
        sim_t2v = text_feat @ gathered_video_feats.t() / self.temp 

        sim_targets = torch.zeros_like(sim_v2t)

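        # hvd.allgather() concatenates in global-rank order, so this worker's
        # block starts at hvd.rank() * b (hvd.local_rank() is only correct on one node).
        rank = hvd.rank()
        b_start, b_end = b * rank, b * (rank + 1)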
        sim_targets[:, b_start: b_end] = torch.eye(b)

        loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1)*sim_targets,dim=1).mean()
        loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1)*sim_targets,dim=1).mean() 

        vtc_loss = (loss_v2t+loss_t2v) / 2

        # ========= ITM ==========
        text_atts = batch['text_input_mask']

        # non-masked text and non-masked image 
        vtm_loss, vtm_logits, vtm_labels = self.compute_vtm(text_embeds=text_embeds, 
                                                            text_atts=text_atts, 
                                                            image_embeds=video_embeds, 
                                                            image_atts=video_atts, 
                                                            sim_i2t=sim_v2t.clone(), # for hard mining
                                                            sim_t2i=sim_t2v.clone()  # for hard mining
                                                           )

        return dict(
            itm_scores=vtm_logits,
            itm_loss=vtm_loss,
            itm_labels=vtm_labels,
            itc_loss=vtc_loss
        )
    
    def compute_vtm(self, text_embeds, text_atts, image_embeds, image_atts, sim_i2t, sim_t2i):
        device = text_embeds.device

        # ====== positive pairs =======
        attention_mask = torch.cat([text_atts, image_atts], dim=1)
        embedding_output_pos = torch.cat([text_embeds, image_embeds], dim=1)

        encoder_outputs_pos = self.text_encoder.bert(encoder_embeds=embedding_output_pos,
                                                     attention_mask=attention_mask,
                                                     return_dict=True,
                                                     mode='fusion'
                                                    )

        # ====== negative pairs =======
        bs = text_embeds.shape[0] 

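        # hvd.allgather() concatenates in global-rank order, so this worker's
        # block starts at hvd.rank() * bs (hvd.local_rank() is only correct on one node).
        rank = hvd.rank()
        b_start, b_end = bs * rank, bs * (rank + 1)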

        with torch.no_grad():
            weights_v2t = sim_i2t[:,b_start:b_end]
            weights_t2v = sim_t2i[:,b_start:b_end]
   
            # never select self as negative
            weights_v2t.fill_diagonal_(float('-inf'))
            weights_t2v.fill_diagonal_(float('-inf'))

            weights_v2t = F.softmax(weights_v2t, dim=1)
            weights_t2v = F.softmax(weights_t2v, dim=1)

        # select a negative image for each text
        # FIXME to optimize using indexing operations
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2v[b], 1).item()
            image_embeds_neg.append(image_embeds[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   

        # select a negative text for each image
        text_embeds_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_v2t[b], 1).item()
            text_embeds_neg.append(text_embeds[neg_idx])
            text_atts_neg.append(text_atts[neg_idx])

        text_embeds_neg = torch.stack(text_embeds_neg,dim=0)   
        text_atts_neg = torch.stack(text_atts_neg,dim=0)      

        text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)     
        text_atts_all = torch.cat([text_atts, text_atts_neg],dim=0)     

        video_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
        video_atts_all = torch.cat([image_atts,image_atts],dim=0)

        attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)
        embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)

        # forward negative pairs via cross encoder
        encoder_outputs_neg = self.text_encoder.bert(encoder_embeds=embedding_output_all,
                                                     attention_mask=attention_mask_all,
                                                     return_dict=True,
                                                     mode='fusion'
                                                    )

        vl_embeddings = torch.cat([encoder_outputs_pos.last_hidden_state[:,0,:], 
                                   encoder_outputs_neg.last_hidden_state[:,0,:]],dim=0)
        vtm_logits = self.itm_head(vl_embeddings)            

        vtm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], dim=0).to(device)
        vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)     

        return vtm_loss, vtm_logits, vtm_labels 

    def forward_inference(self, batch):
        visual_inputs = batch['visual_inputs']
        text_input_mask = batch['text_input_mask']
        text_input_ids = batch['text_input_ids']

        device = visual_inputs.device

        b, t, c, h, w = visual_inputs.shape
        # timeSformer asks for (b, c, t, h, w) as input.
        visual_inputs = visual_inputs.transpose(1, 2)

        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
        video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  

        # Pair the single video in this batch (b == 1) with every text query.
        video_embeds = video_embeds.repeat(text_input_mask.shape[0], 1, 1)
        # image_feat = image_feat.repeat(text_input_mask.shape[0], 1)

        video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)
        text_output = self.text_encoder.bert(text_input_ids,
                                             attention_mask=text_input_mask,
                                             return_dict=True,
                                             mode='text'
                                            )
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 

        vtc_sim_scores = video_feat @ text_feat.t() / self.temp

        attention_mask = torch.cat([text_input_mask, video_atts], dim=1)
        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)

        encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,
                                                 attention_mask=attention_mask,
                                                 return_dict=True,
                                                 mode='fusion'
                                                )

        vl_embeddings = encoder_outputs.last_hidden_state[:,0,:]
        logits = self.itm_head(vl_embeddings)

        return dict(logits=logits, itc_scores=vtc_sim_scores)
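

# Sketch (hypothetical helper, not part of the original ALPRO code): the
# symmetric video-text contrastive (VTC) loss used in the forward passes above,
# pulled out as a standalone function. Assumption: this worker's `b` local
# samples occupy rows [rank * b, (rank + 1) * b) of the all-gathered features,
# i.e. every worker contributes the same batch size. With rank=0 and the local
# features passed as the "gathered" ones, it reduces to plain in-batch InfoNCE.
import torch
import torch.nn.functional as F


def _vtc_loss_sketch(video_feat, text_feat, gathered_video, gathered_text, temp, rank=0):
    b = video_feat.size(0)
    sim_v2t = video_feat @ gathered_text.t() / temp  # (b, world_size * b)
    sim_t2v = text_feat @ gathered_video.t() / temp  # (b, world_size * b)
    # One-hot targets: local sample i matches column rank * b + i.
    targets = torch.zeros_like(sim_v2t)
    targets[:, rank * b:(rank + 1) * b] = torch.eye(b, device=sim_v2t.device, dtype=sim_v2t.dtype)
    loss_v2t = -(F.log_softmax(sim_v2t, dim=1) * targets).sum(dim=1).mean()
    loss_t2v = -(F.log_softmax(sim_t2v, dim=1) * targets).sum(dim=1).mean()
    return (loss_v2t + loss_t2v) / 2

# With one-hot targets this equals cross-entropy against the shifted indices,
# e.g. F.cross_entropy(sim_v2t, torch.arange(b) + rank * b).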

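
# Sketch (hypothetical helper, not part of the original ALPRO code): a vectorized
# take on the two per-sample loops in compute_vtm() that its FIXME points at.
# torch.multinomial accepts a 2-D probability matrix and draws independently per
# row, so all negatives come from one call. Assumes a square local similarity
# matrix with at least two samples (with one sample, masking the diagonal leaves
# no valid negative and the softmax row becomes NaN).
import torch


def _sample_hard_negatives_sketch(sim):
    # Never select the positive pair (the diagonal) as a negative.
    weights = sim.detach().clone()
    weights.fill_diagonal_(float('-inf'))
    probs = torch.softmax(weights, dim=1)
    # One index per row, drawn with probability proportional to exp(similarity).
    return torch.multinomial(probs, 1).squeeze(1)

# Possible use inside compute_vtm(), where `local_t2v` / `local_v2t` (hypothetical
# names) are the (bs, bs) slices sim_t2i[:, b_start:b_end] / sim_i2t[:, b_start:b_end]:
#     image_embeds_neg = image_embeds[_sample_hard_negatives_sketch(local_t2v)]
#     neg_idx = _sample_hard_negatives_sketch(local_v2t)
#     text_embeds_neg, text_atts_neg = text_embeds[neg_idx], text_atts[neg_idx]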


================================================
FILE: src/modeling/timesformer/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

# from .build import MODEL_REGISTRY, build_model  # noqa
# from .custom_video_model_builder import *  # noqa
# from .video_model_builder import ResNet, SlowFast # noqa


================================================
FILE: src/modeling/timesformer/conv2d_same.py
================================================
# Copyright 2020 Ross Wightman
# Conv2d w/ Same Padding

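import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# Padding helpers used by get_padding_value() (modeled on timm's padding.py).
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    # Symmetric PyTorch-style padding that preserves the spatial size when stride is 1.
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2


def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> bool:
    # 'SAME' padding is independent of the input size only for stride 1 and an even total pad.
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0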

# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    ih, iw = x.size()[-2:]
    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
    return x

# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
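

# Sketch (hypothetical helper, not part of the original file): a check that
# TF-style 'SAME' padding yields an output length of ceil(input / stride).
# Worked example for input 7, kernel 3, stride 2, dilation 1:
#   pad = max((ceil(7 / 2) - 1) * 2 + (3 - 1) * 1 + 1 - 7, 0) = max(6 + 2 + 1 - 7, 0) = 2
# so one pixel is added on each side and the strided conv produces ceil(7 / 2) = 4 outputs.
import math

import torch
import torch.nn.functional as F


def _same_conv_output_size(size: int, kernel: int, stride: int, dilation: int = 1) -> int:
    """Pad a (1, 1, size, size) input TF-'SAME' style, convolve, and return the output width."""
    pad = max((math.ceil(size / stride) - 1) * stride + (kernel - 1) * dilation + 1 - size, 0)
    x = torch.zeros(1, 1, size, size)
    # Asymmetric split: an odd extra pixel goes to the right/bottom, as in pad_same().
    x = F.pad(x, [pad // 2, pad - pad // 2, pad // 2, pad - pad // 2])
    weight = torch.zeros(1, 1, kernel, kernel)
    out = F.conv2d(x, weight, stride=stride, dilation=dilation)
    return out.shape[-1]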

def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic

def conv2d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
    x = pad_same(x, weight.shape[-2:], stride, dilation)
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
    if is_dynamic:
        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
    else:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)


================================================
FILE: src/modeling/timesformer/features.py
================================================
# Copyright 2020 Ross Wightman

from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Tuple

import torch
import torch.nn as nn


class FeatureInfo:

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        prev_reduction = 1
        for fi in feature_info:
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        return FeatureInfo(deepcopy(self.info), out_indices)

    def get(self, key, idx=None):
        """ Get value by key at specified index (indices)
        if idx is None, returns value for key at each output index
        if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tuple, return value for each module index (ignoring output indices)
        """
        if idx is None:
            return [self.info[i][key] for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i][key] for i in idx]
        else:
            return self.info[idx][key]

    def get_dicts(self, keys=None, idx=None):
        """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
        """
        if idx is None:
            if keys is None:
                return [self.info[i] for i in self.out_indices]
            else:
                return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
        else:
            return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}

    def channels(self, idx=None):
        """ feature channels accessor
        """
        return self.get('num_chs', idx)

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        """
        return self.get('reduction', idx)

    def module_name(self, idx=None):
        """ feature module name accessor
        """
        return self.get('module', idx)

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)


class FeatureHooks:
    """ Feature Hook Helper
    This module helps with the setup and extraction of hooks for extracting features from
    internal nodes in a model by node name. This works quite well in eager Python but needs
    redesign for TorchScript.
    """

    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
        # setup feature hooks
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"
        self._feature_outputs = defaultdict(OrderedDict)

    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> Dict[str, torch.Tensor]:
        output = self._feature_outputs[device]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output
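

# Sketch (hypothetical helper, not part of the original file): the core idea behind
# FeatureHooks in a few lines. Forward hooks are registered on modules looked up by
# their dotted name in model.named_modules(), each hook stores its module's output,
# and the hooks are removed after one forward pass. Unlike FeatureHooks, outputs are
# not keyed by device, so this version assumes a single-device forward.
import torch
import torch.nn as nn


def _capture_outputs_sketch(model, module_names, x):
    modules = dict(model.named_modules())
    captured = {}
    handles = []
    for name in module_names:
        # Bind `name` as a default argument so each hook records under its own key.
        def hook(module, inputs, output, name=name):
            captured[name] = output
        handles.append(modules[name].register_forward_hook(hook))
    try:
        model(x)
    finally:
        for handle in handles:
            handle.remove()
    return captured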


def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into containing model
            for child_name, child_module in module.named_children():
                combined = [name, child_name]
                ml.append(('_'.join(combined), '.'.join(combined), child_module))
        else:
            ml.append((name, name, module))
    return ml


def _get_feature_info(net, out_indices):
    feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


def _get_return_layers(feature_info, out_map):
    module_names = feature_info.module_name()
    return_layers = {}
    for i, name in enumerate(module_names):
        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
    return return_layers


class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return
    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.
    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.
    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
    Arguments:
        model (nn.Module): model from which we will extract the features
        out_indices (tuple[int]): model output indices to extract features for
        out_map (sequence): list or tuple specifying desired return id for each out index,
            otherwise str(index) is used
        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
            vs select element [0]
        flatten_sequential (bool): whether to flatten sequential modules assigned to model
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.concat = feature_concat
        self.return_layers = {}
        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)
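

# Sketch (hypothetical helper, not part of the original file): the "re-build"
# strategy of FeatureDictNet reduced to its core. Direct children are run in
# registration order and the outputs of the requested children are kept. This is
# only correct under the assumption stated in the FeatureDictNet docstring:
# children are registered in the order forward() uses them and none is reused.
from collections import OrderedDict

import torch
import torch.nn as nn


def _collect_child_outputs_sketch(model, return_names, x):
    out = OrderedDict()
    remaining = set(return_names)
    for name, child in model.named_children():
        x = child(x)
        if name in remaining:
            out[name] = x
            remaining.discard(name)
        if not remaining:
            break  # later children are not needed
    return out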


class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return
    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
        super(FeatureListNet, self).__init__(
            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
            flatten_sequential=flatten_sequential)

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())


class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet
    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.
    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self, model,
            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
        super(FeatureHookNet, self).__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.out_as_dict = out_as_dict
        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                         for f in self.feature_info.get_dicts()}
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                if not remaining:
                    break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

    def forward(self, x):
        for name, module in self.items():
            x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.out_as_dict else list(out.values())


================================================
FILE: src/modeling/timesformer/helpers.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified model creation / weight loading / state_dict helpers

import logging
import os
import sys
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

from src.modeling.timesformer.features import FeatureListNet, FeatureDictNet, FeatureHookNet
from src.modeling.timesformer.conv2d_same import Conv2dSame
from src.modeling.timesformer.linear import Linear

from horovod import torch as hvd

_logger = logging.getLogger()

def load_state_dict(checkpoint_path, use_ema=False):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict_key = 'state_dict'
        if isinstance(checkpoint, dict):
            if use_ema and 'state_dict_ema' in checkpoint:
                state_dict_key = 'state_dict_ema'
        if state_dict_key and state_dict_key in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint[state_dict_key].items():
                # strip `module.` prefix
                name = k[7:] if k.startswith('module.') else k
                new_state_dict[name] = v
            state_dict = new_state_dict
        elif 'model_state' in checkpoint:
            state_dict_key = 'model_state'
            new_state_dict = OrderedDict()
            for k, v in checkpoint[state_dict_key].items():
                # strip `model.` prefix
                name = k[6:] if k.startswith('model.') else k
                new_state_dict[name] = v
            state_dict = new_state_dict
        else:
            state_dict = checkpoint
        _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
        return state_dict
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()


def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
    state_dict = load_state_dict(checkpoint_path, use_ema)
    model.load_state_dict(state_dict, strict=strict)
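

# Sketch (hypothetical helper, not part of the original file): prefix stripping as
# done in load_state_dict(). Wrapping a model in nn.DataParallel or
# DistributedDataParallel stores it as `.module`, so its state_dict keys gain a
# `module.` prefix. Matching the full prefix, dot included, leaves keys that merely
# start with the same letters (e.g. `modules_x`) untouched.
from collections import OrderedDict


def _strip_prefix_sketch(state_dict, prefix='module.'):
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )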


# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
#     resume_epoch = None
#     if os.path.isfile(checkpoint_path):
#         checkpoint = torch.load(checkpoint_path, map_location='cpu')
#         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
#             if log_info:
#                 _logger.info('Restoring model state from checkpoint...')
#             new_state_dict = OrderedDict()
#             for k, v in checkpoint['state_dict'].items():
#                 name = k[7:] if k.startswith('module') else k
#                 new_state_dict[name] = v
#             model.load_state_dict(new_state_dict)

#             if optimizer is not None and 'optimizer' in checkpoint:
#                 if log_info:
#                     _logger.info('Restoring optimizer state from checkpoint...')
#                 optimizer.load_state_dict(checkpoint['optimizer'])

#             if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
#                 if log_info:
#                     _logger.info('Restoring AMP loss scaler state from checkpoint...')
#                 loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

#             if 'epoch' in checkpoint:
#                 resume_epoch = checkpoint['epoch']
#                 if 'version' in checkpoint and checkpoint['version'] > 1:
#                     resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

#             if log_info:
#                 _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
#         else:
#             model.load_state_dict(checkpoint)
#             if log_info:
#                 _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
#         return resume_epoch
#     else:
#         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
#         raise FileNotFoundError()

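# Sketch (not part of the original helpers): normalizing checkpoint keys saved
# from a DataParallel/DistributedDataParallel-wrapped model, whose parameter
# names carry a leading 'module.' prefix. The commented-out resume code above
# does this with `k[7:] if k.startswith('module') else k`; this hypothetical
# helper matches the full 'module.' prefix instead, so a key that merely starts
# with the letters "module" (e.g. 'modules_list.0.weight') is left intact.
# Any mapping works; values are passed through unchanged.
def _strip_module_prefix_sketch(state_dict, prefix='module.'):
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Drop the wrapper prefix only when the key really starts with it.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict
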

def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8, num_patches=196, attention_type='divided_space_time', pretrained_model="", strict=True):
    if cfg is None:
        cfg = getattr(model, 'default_cfg', None)
    if cfg is None or 'url' not in cfg or not cfg['url']:
        _logger.warning("Pretrained model URL is invalid, using random initialization.")
        return

    if len(pretrained_model) == 0:
        _logger.info(f"Loading pretrained weights from {cfg['url']}.")
        state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
    else:
        state_dict = load_state_dict(pretrained_model)
        # Some checkpoints nest the weights under a 'model' key; unwrap them if so.
        if 'model' in state_dict:
            state_dict = state_dict['model']


    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    if in_chans == 1:
        conv1_name = cfg['first_conv']
        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
        conv1_weight = state_dict[conv1_name + '.weight']
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I > 3:
            assert conv1_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
        else:
            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
        conv1_weight = conv1_weight.to(conv1_type)
        state_dict[conv1_name + '.weight'] = conv1_weight
    elif in_chans != 3:
        conv1_name = cfg['first_conv']
        conv1_weight = state_dict[conv1_name + '.weight']
        conv1_type = conv1_weight.dtype
        conv1_weight = conv1_weight.float()
        O, I, J, K = conv1_weight.shape
        if I != 3:
            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
            del state_dict[conv1_name + '.weight']
            strict = False
        else:
            _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
            repeat = int(math.ceil(in_chans / 3))
            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv1_weight *= (3 / float(in_chans))
            conv1_weight = conv1_weight.to(conv1_type)
            state_dict[conv1_name + '.weight'] = conv1_weight


    classifier_name = cfg['classifier']
    if num_classes == 1000 and cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != state_dict[classifier_name + '.weight'].size(0):
        # Completely discard the classifier for any other mismatch between the pretrained and created model
        _logger.info(f"Removing classifier '{classifier_name}' due to a size mismatch "
                     f"({num_classes} != {state_dict[classifier_name + '.weight'].size(0)}).")
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False


    ## Resize the spatial position embeddings if the number of patches differs
    if num_patches + 1 != state_dict['pos_embed'].size(1):
        _logger.info(f"Resizing spatial position embedding from {state_dict['pos_embed'].size(1)} to {num_patches + 1}")
        pos_embed = state_dict['pos_embed']  # (1, N_old + 1, D)
        cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)  # (1, 1, D): [CLS] kept as-is
        other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)  # (1, D, N_old)
        new_pos_embed = F.interpolate(other_pos_embed, size=num_patches, mode='nearest')  # (1, D, num_patches)
        new_pos_embed = new_pos_embed.transpose(1, 2)
        new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)  # (1, num_patches + 1, D)
        state_dict['pos_embed'] = new_pos_embed

    ## Resize the temporal position embeddings if the number of frames differs
    if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):
        _logger.info(f"Resizing temporal position embedding from {state_dict['time_embed'].size(1)} to {num_frames}")
        time_embed = state_dict['time_embed'].transpose(1, 2)  # (1, D, T_old)
        new_time_embed = F.interpolate(time_embed, size=num_frames, mode='nearest')  # (1, D, num_frames)
        state_dict['time_embed'] = new_time_embed.transpose(1, 2)

    ## Initialize temporal attention (and its norm) from the spatial weights when absent
    if attention_type == 'divided_space_time':
        new_state_dict = state_dict.copy()
        for key in state_dict:
            if 'blocks' in key and 'attn' in key:
                new_key = key.replace('attn', 'temporal_attn')
                if new_key not in state_dict:
                    new_state_dict[new_key] = state_dict[key]
            if 'blocks' in key and 'norm1' in key:
                new_key = key.replace('norm1', 'temporal_norm1')
                if new_key not in state_dict:
                    new_state_dict[new_key] = state_dict[key]
        state_dict = new_state_dict

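    ## Load the weights. The `strict` argument is not used here: loading is always
    ## non-strict, so missing or unexpected keys (e.g. extra copies produced by the
    ## temporal-attention expansion above) do not raise an error.
    model.load_state_dict(state_dict, strict=False)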

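
# Sketch (not part of the original helpers): the first-conv channel adaptation
# in `load_pretrained` above, restated on plain Python lists so the arithmetic
# is easy to check. Simplifying assumptions: each output filter is reduced to
# one scalar per input channel (the J x K spatial kernel is dropped), and only
# the 3-channel (I == 3) pretrained case is covered.
#   * in_chans == 1: the RGB weights are summed, so a grayscale frame copied
#     into all three channels would give the same response as before.
#   * any other in_chans: the RGB weights are tiled until in_chans channels
#     exist, then scaled by 3 / in_chans, which keeps each filter's total
#     weight unchanged whenever in_chans is a multiple of 3.
def _adapt_first_conv_sketch(weight, in_chans):
    """weight: list of O filters, each a list of 3 per-channel weights."""
    import math

    if in_chans == 1:
        return [[sum(filt)] for filt in weight]
    repeat = int(math.ceil(in_chans / 3))
    scale = 3 / float(in_chans)
    # Tile the 3 channel weights, truncate to in_chans, then rescale.
    return [[w * scale for w in (filt * repeat)[:in_chans]] for filt in weight]
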

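# Sketch (not part of the original helpers): the spatial position-embedding
# resize in `load_pretrained`, on a plain list of per-token entries. Entry 0
# is the [CLS] embedding and is kept as-is; the remaining N_old patch entries
# are resampled to `num_patches` with 1-D nearest-neighbour lookup. Assumption:
# this follows the index rule of F.interpolate(..., mode='nearest') with an
# explicit output size, where output i copies input floor(i * N_old / num_patches).
# The temporal `time_embed` resize applies the same rule with no [CLS] entry.
def _resize_pos_embed_nearest_sketch(pos_embed, num_patches):
    cls_embed, patch_embeds = pos_embed[0], pos_embed[1:]
    n_old = len(patch_embeds)
    scale = n_old / num_patches
    # Clamp to the last index, as nearest interpolation does at the boundary.
    resized = [patch_embeds[min(int(i * scale), n_old - 1)] for i in range(num_patches)]
    return [cls_embed] + resized

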
def load_pretrained_CLIP_ViT(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):
    if hvd.rank() == 0:
        _logger.info("Loading CLIP ViT-B/16 checkpoints.")
    loaded_state_dict = torch.load(pretrained_model, map_location='cpu')

    ## Initialize temporal attention (and its norm) from the spatial weights when absent
    new_state_dict = loaded_state_dict.copy()
    for key in loaded_state_dict:
        if 'blocks' in key and 'attn' in key:
            new_key = key.replace('attn', 'temporal_attn')
            if new_key not in loaded_state_dict:
                new_state_dict[new_key] = loaded_state_dict[key]
        if 'blocks' in key and 'norm1' in key:
            new_key = key.replace('norm1', 'temporal_norm1')
            if new_key not in loaded_state_dict:
                new_state_dict[new_key] = loaded_state_dict[key]

    loaded_state_dict = new_state_dict

    loaded_keys = loaded_state_dict.keys()
    # Build the model state dict once instead of once per key inside the loop.
    model_state_dict = model.state_dict()
    model_keys = model_state_dict.keys()

    load_not_in_model = [k for k in loaded_keys if k not in model_keys]
    model_not_in_load = [k for k in model_keys if k not in loaded_keys]

    # Load only keys present in both, and only when the shapes agree.
    toload = dict()
    mismatched_shape_keys = []
    for k in model_keys:
        if k in loaded_keys:
            if model_state_dict[k].shape != loaded_state_dict[k].shape:
                mismatched_shape_keys.append(k)
            else:
                toload[k] = loaded_state_dict[k]

    if hvd.rank() == 0:
        _logger.info("Keys in loaded but not in model:")
        _logger.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
        _logger.info("Keys in model but not in loaded:")
        _logger.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
        _logger.info("Keys in model and loaded, but shape mismatched:")
        _logger.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}")

    model.load_state_dict(toload, strict=False)

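
# Sketch (not part of the original helpers): the "initialize temporal
# attention" loop that the loaders in this file repeat, factored into one
# hypothetical helper. For every block parameter whose name contains 'attn' or
# 'norm1', a copy is added under the 'temporal_attn' / 'temporal_norm1' name
# unless the checkpoint already has that key, so the temporal branch starts
# from the spatial weights. As in the original loops, a key that already has a
# 'temporal_' prefix is expanded too (e.g. into a 'temporal_temporal_attn'
# name); such keys match no model parameter and are not loaded.
def _init_temporal_from_spatial_sketch(state_dict):
    new_state_dict = dict(state_dict)
    for key in state_dict:
        if 'blocks' not in key:
            continue
        for spatial, temporal in (('attn', 'temporal_attn'), ('norm1', 'temporal_norm1')):
            if spatial in key:
                new_key = key.replace(spatial, temporal)
                # Never overwrite temporal weights that the checkpoint provides.
                if new_key not in state_dict:
                    new_state_dict[new_key] = state_dict[key]
    return new_state_dict
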

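# Sketch (not part of the original helpers): the key-matching step of
# `load_pretrained_CLIP_ViT` (and `load_pretrained_imagenet`), written against
# plain {name: shape} mappings so it runs without a model. Assumption: shapes
# are tuples, e.g. built with `tuple(t.shape)`. It returns the names that would
# be loaded plus the three diagnostic lists that the loaders log.
def _partition_keys_sketch(model_shapes, loaded_shapes):
    load_not_in_model = sorted(k for k in loaded_shapes if k not in model_shapes)
    model_not_in_load = sorted(k for k in model_shapes if k not in loaded_shapes)
    to_load, mismatched = [], []
    for k in model_shapes:
        if k not in loaded_shapes:
            continue
        # Shared keys are loaded only when their shapes agree exactly.
        if model_shapes[k] == loaded_shapes[k]:
            to_load.append(k)
        else:
            mismatched.append(k)
    return to_load, load_not_in_model, model_not_in_load, sorted(mismatched)

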
def load_pretrained_imagenet(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):
    import timm

    if hvd.rank() == 0:
        _logger.info("Loading vit_base_patch16_224 checkpoints.")
    loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True).state_dict()

    del loaded_state_dict['head.weight']
    del loaded_state_dict['head.bias']

    ## Initialize temporal attention (and its norm) from the spatial weights when absent
    new_state_dict = loaded_state_dict.copy()
    for key in loaded_state_dict:
        if 'blocks' in key and 'attn' in key:
            new_key = key.replace('attn', 'temporal_attn')
            if new_key not in loaded_state_dict:
                new_state_dict[new_key] = loaded_state_dict[key]
        if 'blocks' in key and 'norm1' in key:
            new_key = key.replace('norm1', 'temporal_norm1')
            if new_key not in loaded_state_dict:
                new_state_dict[new_key] = loaded_state_dict[key]

    loaded_state_dict = new_state_dict

    loaded_keys = loaded_state_dict.keys()
    # Build the model state dict once instead of once per key inside the loop.
    model_state_dict = model.state_dict()
    model_keys = model_state_dict.keys()

    load_not_in_model = [k for k in loaded_keys if k not in model_keys]
    model_not_in_load = [k for k in model_keys if k not in loaded_keys]

    # Load only keys present in both, and only when the shapes agree.
    toload = dict()
    mismatched_shape_keys = []
    for k in model_keys:
        if k in loaded_keys:
            if model_state_dict[k].shape != loaded_state_dict[k].shape:
                mismatched_shape_keys.append(k)
            else:
                toload[k] = loaded_state_dict[k]

    if hvd.rank() == 0:
        _logger.info("Keys in loaded but not in model:")
        _logger.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
        _logger.info("Keys in model but not in loaded:")
        _logger.info(f"In total {len(model_not_in_load)}, {s
SYMBOL INDEX (509 symbols across 30 files)

FILE: src/configs/config.py
  function parse_with_config (line 12) | def parse_with_config(parsed_args):
  class SharedConfigs (line 32) | class SharedConfigs(object):
    method __init__ (line 42) | def __init__(self, desc="shared config for pretraining and finetuning"):
    method parse_args (line 213) | def parse_args(self):
    method get_sparse_pretraining_args (line 244) | def get_sparse_pretraining_args(self):
    method get_video_retrieval_args (line 279) | def get_video_retrieval_args(self):
    method get_nlvl_args (line 287) | def get_nlvl_args(self):
    method get_vqa_args (line 293) | def get_vqa_args(self):
    method get_video_qa_args (line 308) | def get_video_qa_args(self):

FILE: src/datasets/data_utils.py
  function mask_batch_text_tokens (line 23) | def mask_batch_text_tokens(
  function select_batch_text_pivots (line 73) | def select_batch_text_pivots(
  function image_to_tensor (line 182) | def image_to_tensor(image: np.ndarray, keepdim: bool = True) -> torch.Te...
  function get_padding (line 221) | def get_padding(image, max_w, max_h, pad_all=False):
  class ImagePad (line 245) | class ImagePad(object):
    method __init__ (line 246) | def __init__(self, max_w, max_h, fill=0, padding_mode='constant'):
    method __call__ (line 254) | def __call__(self, img):
    method __repr__ (line 271) | def __repr__(self):
  function get_resize_size (line 276) | def get_resize_size(image, max_size):
  class VideoRandomSquareCrop (line 310) | class VideoRandomSquareCrop(object):
    method __init__ (line 311) | def __init__(self, crop_size, p=0.5):
    method __call__ (line 316) | def __call__(self, video):
  class VideoResizeSquare (line 342) | class VideoResizeSquare(object):
    method __init__ (line 343) | def __init__(self, out_size, interpolation='nearest'):
    method __call__ (line 348) | def __call__(self, video):
    method __repr__ (line 378) | def __repr__(self):
  class ImageResize (line 383) | class ImageResize(object):
    method __init__ (line 396) | def __init__(self, max_size, interpolation=Image.BILINEAR):
    method __call__ (line 401) | def __call__(self, img):
    method __repr__ (line 417) | def __repr__(self):
  function get_imagenet_transform (line 423) | def get_imagenet_transform(min_size=600, max_size=1000):
  class ImageNorm (line 437) | class ImageNorm(object):
    method __init__ (line 440) | def __init__(self, mean, std):
    method __call__ (line 447) | def __call__(self, img):
  function chunk_list (line 460) | def chunk_list(examples, chunk_size=2, pad_to_divisible=True):
  function mk_input_group (line 488) | def mk_input_group(key_grouped_examples, max_n_example_per_group=1, is_t...

FILE: src/datasets/dataloader.py
  class MetaLoader (line 14) | class MetaLoader(object):
    method __init__ (line 16) | def __init__(self, loaders, accum_steps=1, distributed=False):
    method __iter__ (line 38) | def __iter__(self):
  function move_to_cuda (line 59) | def move_to_cuda(batch):
  function record_cuda_stream (line 73) | def record_cuda_stream(batch):
  class PrefetchLoader (line 86) | class PrefetchLoader(object):
    method __init__ (line 91) | def __init__(self, loader, img_normalize=None):
    method __iter__ (line 96) | def __iter__(self):
    method __len__ (line 122) | def __len__(self):
    method preload (line 125) | def preload(self, it):
    method next (line 150) | def next(self, it):
    method __getattr__ (line 158) | def __getattr__(self, name):
  class InfiniteIterator (line 163) | class InfiniteIterator(object):
    method __init__ (line 165) | def __init__(self, iterable):
    method __iter__ (line 169) | def __iter__(self):

FILE: src/datasets/dataset_base.py
  class AlproBaseDataset (line 18) | class AlproBaseDataset(Dataset):
    method __init__ (line 35) | def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type='lmd...
    method __len__ (line 62) | def __len__(self):
    method __getitem__ (line 65) | def __getitem__(self, index):
    method _load_img (line 68) | def _load_img(self, img_id):
    method _is_extreme_aspect_ratio (line 86) | def _is_extreme_aspect_ratio(cls, tensor, max_ratio=5.):
    method _load_video (line 95) | def _load_video(self, video_id, num_clips=None, clip_idx=None,
    method _load_video_from_path_decord (line 137) | def _load_video_from_path_decord(self, video_path, height=None, width=...
  function img_collate (line 184) | def img_collate(imgs):

FILE: src/datasets/dataset_pretrain_sparse.py
  class AlproPretrainSparseDataset (line 22) | class AlproPretrainSparseDataset(AlproBaseDataset):
    method __init__ (line 36) | def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type, txt...
    method __len__ (line 65) | def __len__(self):
    method __getitem__ (line 68) | def __getitem__(self, index):
  class PretrainImageTextDataset (line 125) | class PretrainImageTextDataset(Dataset):
    method __init__ (line 126) | def __init__(self, datalist, tokenizer, is_train=True, crop_size=256, ...
    method __len__ (line 143) | def __len__(self):
    method __getitem__ (line 146) | def __getitem__(self, index):
  class PretrainCollator (line 196) | class PretrainCollator(object):
    method __init__ (line 200) | def __init__(self, tokenizer,
    method collate_batch (line 214) | def collate_batch(self, batch):
  function random_erase (line 277) | def random_erase(input_img, patch_size, s_l=0.3, s_h=0.5, r_1=0.3, r_2=1...

FILE: src/datasets/dataset_video_qa.py
  class AlproVideoQADataset (line 13) | class AlproVideoQADataset(AlproBaseDataset):
    method __init__ (line 28) | def __init__(self, task_type, datalist, tokenizer, img_lmdb_dir,
    method __len__ (line 55) | def __len__(self):
    method __getitem__ (line 59) | def __getitem__(self, index):
    method _get_single_example (line 89) | def _get_single_example(self, data):
    method evaluate_qa (line 102) | def evaluate_qa(self, results):
  class VideoQACollator (line 158) | class VideoQACollator(object):
    method __init__ (line 159) | def __init__(self, tokenizer, max_length=20, task_type="action", n_opt...
    method collate_batch (line 165) | def collate_batch(self, batch):

FILE: src/datasets/dataset_video_retrieval.py
  class AlproVideoRetrievalDataset (line 13) | class AlproVideoRetrievalDataset(AlproBaseDataset):
    method __init__ (line 22) | def __init__(self, datalist, tokenizer, img_lmdb_dir,
    method __len__ (line 47) | def __len__(self):
    method __getitem__ (line 50) | def __getitem__(self, index):
    method _get_single_example (line 83) | def _get_single_example(self, data, index):
  class VideoRetrievalCollator (line 95) | class VideoRetrievalCollator(object):
    method __init__ (line 96) | def __init__(self, tokenizer, max_length=40):
    method collate_batch (line 100) | def collate_batch(self, batch):
  class AlproVideoRetrievalEvalDataset (line 143) | class AlproVideoRetrievalEvalDataset(AlproBaseDataset):
    method __init__ (line 153) | def __init__(self, datalist, tokenizer, img_lmdb_dir,
    method __len__ (line 174) | def __len__(self):
    method __getitem__ (line 177) | def __getitem__(self, index):
    method _prepare_batches_by_video (line 198) | def _prepare_batches_by_video(self):

FILE: src/datasets/randaugment.py
  function identity_func (line 7) | def identity_func(img):
  function autocontrast_func (line 11) | def autocontrast_func(img, cutoff=0):
  function equalize_func (line 44) | def equalize_func(img):
  function rotate_func (line 67) | def rotate_func(img, degree, fill=(0, 0, 0)):
  function horizontal_flip_func (line 78) | def horizontal_flip_func(img):
  function solarize_func (line 88) | def solarize_func(img, thresh=128):
  function color_func (line 98) | def color_func(img, factor):
  function contrast_func (line 120) | def contrast_func(img, factor):
  function brightness_func (line 133) | def brightness_func(img, factor):
  function sharpness_func (line 142) | def sharpness_func(img, factor):
  function shear_x_func (line 163) | def shear_x_func(img, factor, fill=(0, 0, 0)):
  function translate_x_func (line 170) | def translate_x_func(img, offset, fill=(0, 0, 0)):
  function translate_y_func (line 180) | def translate_y_func(img, offset, fill=(0, 0, 0)):
  function posterize_func (line 190) | def posterize_func(img, bits):
  function shear_y_func (line 198) | def shear_y_func(img, factor, fill=(0, 0, 0)):
  function enhance_level_to_args (line 219) | def enhance_level_to_args(MAX_LEVEL):
  function shear_level_to_args (line 225) | def shear_level_to_args(MAX_LEVEL, replace_value):
  function translate_level_to_args (line 234) | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
  function cutout_level_to_args (line 243) | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
  function solarize_level_to_args (line 251) | def solarize_level_to_args(MAX_LEVEL):
  function none_level_to_args (line 258) | def none_level_to_args(level):
  function posterize_level_to_args (line 262) | def posterize_level_to_args(MAX_LEVEL):
  function rotate_level_to_args (line 269) | def rotate_level_to_args(MAX_LEVEL, replace_value):
  class TemporalConsistentRandomAugment (line 323) | class TemporalConsistentRandomAugment(object):
    method __init__ (line 325) | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
    method get_random_ops (line 335) | def get_random_ops(self):
    method __call__ (line 340) | def __call__(self, frames):
    method _aug (line 355) | def _aug(self, img, ops, apply_or_not):
  class RandomAugment (line 363) | class RandomAugment(object):
    method __init__ (line 365) | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
    method get_random_ops (line 374) | def get_random_ops(self):
    method __call__ (line 378) | def __call__(self, img):
  function save_frames_grid (line 390) | def save_frames_grid(img_array, out_path):
  function stack (line 415) | def stack(data, dim=0):

FILE: src/modeling/alpro_models.py
  class AlproBaseModel (line 19) | class AlproBaseModel(nn.Module):
    method __init__ (line 20) | def __init__(self, config=None, input_format='RGB', video_enc_cfg=None...
    method load_separate_ckpt (line 45) | def load_separate_ckpt(self, visual_weights_path=None, bert_weights_pa...
  class AlproForPretrain (line 58) | class AlproForPretrain(AlproBaseModel):
    method __init__ (line 59) | def __init__(self, config, video_enc_cfg, input_format='RGB'):
    method build_text_prompts (line 73) | def build_text_prompts(self, prompts):
    method get_pseudo_labels (line 76) | def get_pseudo_labels(self, batch):
    method forward (line 79) | def forward(self, batch):
    method _forward_visual_embeds (line 186) | def _forward_visual_embeds(self, visual_inputs):
    method _forward_text_feats (line 196) | def _forward_text_feats(self, batch):
    method compute_mpm_with_encoder_out (line 209) | def compute_mpm_with_encoder_out(self, encoder_outputs, text_atts, sof...
    method compute_mpm (line 234) | def compute_mpm(self, text_embeds, text_atts, image_embeds, image_atts...
    method compute_vtm (line 269) | def compute_vtm(self, text_embeds, text_atts, video_embeds, video_atts...
    method compute_mlm (line 346) | def compute_mlm(self, input_ids, text_input_mask, video_embeds, video_...
    method load_separate_ckpt (line 375) | def load_separate_ckpt(self, visual_weights_path=None, bert_weights_pa...
  class Prompter (line 389) | class Prompter(AlproBaseModel):
    method __init__ (line 390) | def __init__(self, config, video_enc_cfg, input_format='RGB'):
    method load_pretrained_weights_without_prompts (line 404) | def load_pretrained_weights_without_prompts(self, ckpt_path):
    method build_text_prompts (line 430) | def build_text_prompts(self, prompts):
    method _forward_visual_embeds (line 509) | def _forward_visual_embeds(self, visual_inputs):
    method _compute_soft_labels (line 525) | def _compute_soft_labels(self, sim_vp_masked):
    method get_pseudo_labels (line 531) | def get_pseudo_labels(self, batch):
    method forward (line 553) | def forward(self, batch):
    method forward_feats (line 597) | def forward_feats(self, batch):
  class AlproForSequenceClassification (line 633) | class AlproForSequenceClassification(AlproBaseModel):
    method __init__ (line 634) | def __init__(self, config, video_enc_cfg, input_format='RGB'):
    method forward (line 647) | def forward(self, batch):
    method forward_inference (line 691) | def forward_inference(self, batch):
  class AlproForVideoTextRetrieval (line 727) | class AlproForVideoTextRetrieval(AlproBaseModel):
    method __init__ (line 730) | def __init__(self, config, video_enc_cfg, input_format='RGB'):
    method forward (line 733) | def forward(self, batch):
    method compute_vtm (line 800) | def compute_vtm(self, text_embeds, text_atts, image_embeds, image_atts...
    method forward_inference (line 874) | def forward_inference(self, batch):

FILE: src/modeling/timesformer/conv2d_same.py
  function pad_same (line 14) | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value...
  function get_same_padding (line 22) | def get_same_padding(x: int, k: int, s: int, d: int):
  function get_padding_value (line 25) | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bo...
  function conv2d_same (line 47) | def conv2d_same(
  class Conv2dSame (line 54) | class Conv2dSame(nn.Conv2d):
    method __init__ (line 58) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
    method forward (line 63) | def forward(self, x):
  function create_conv2d_pad (line 67) | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):

FILE: src/modeling/timesformer/features.py
  class FeatureInfo (line 12) | class FeatureInfo:
    method __init__ (line 14) | def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
    method from_other (line 25) | def from_other(self, out_indices: Tuple[int]):
    method get (line 28) | def get(self, key, idx=None):
    method get_dicts (line 41) | def get_dicts(self, keys=None, idx=None):
    method channels (line 54) | def channels(self, idx=None):
    method reduction (line 59) | def reduction(self, idx=None):
    method module_name (line 64) | def module_name(self, idx=None):
    method __getitem__ (line 69) | def __getitem__(self, item):
    method __len__ (line 72) | def __len__(self):
  class FeatureHooks (line 76) | class FeatureHooks:
    method __init__ (line 83) | def __init__(self, hooks, named_modules, out_map=None, default_hook_ty...
    method _collect_output_hook (line 100) | def _collect_output_hook(self, hook_id, *args):
    method get_output (line 106) | def get_output(self, device) -> Dict[str, torch.tensor]:
  function _module_list (line 112) | def _module_list(module, flatten_sequential=False):
  function _get_feature_info (line 126) | def _get_feature_info(net, out_indices):
  function _get_return_layers (line 136) | def _get_return_layers(feature_info, out_map):
  class FeatureDictNet (line 144) | class FeatureDictNet(nn.ModuleDict):
    method __init__ (line 164) | def __init__(
    method _collect (line 187) | def _collect(self, x) -> (Dict[str, torch.Tensor]):
    method forward (line 201) | def forward(self, x) -> Dict[str, torch.Tensor]:
  class FeatureListNet (line 205) | class FeatureListNet(FeatureDictNet):
    method __init__ (line 210) | def __init__(
    method forward (line 217) | def forward(self, x) -> (List[torch.Tensor]):
  class FeatureHookNet (line 221) | class FeatureHookNet(nn.ModuleDict):
    method __init__ (line 230) | def __init__(
    method forward (line 262) | def forward(self, x):

FILE: src/modeling/timesformer/helpers.py
  function load_state_dict (line 26) | def load_state_dict(checkpoint_path, use_ema=False):
  function load_checkpoint (line 57) | def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
  function load_pretrained (line 102) | def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filte...
  function load_pretrained_CLIP_ViT (line 213) | def load_pretrained_CLIP_ViT(model, pretrained_model, cfg=None, ignore_c...
  function load_pretrained_imagenet (line 262) | def load_pretrained_imagenet(model, pretrained_model, cfg=None, ignore_c...
  function load_pretrained_kinetics (line 315) | def load_pretrained_kinetics(model, pretrained_model, cfg=None, ignore_c...
  function resize_spatial_embedding (line 355) | def resize_spatial_embedding(state_dict, key, num_patches):
  function resize_temporal_embedding (line 370) | def resize_temporal_embedding(state_dict, key, num_frames):

FILE: src/modeling/timesformer/linear.py
  class Linear (line 7) | class Linear(nn.Linear):
    method forward (line 8) | def forward(self, input: torch.Tensor) -> torch.Tensor:

FILE: src/modeling/timesformer/vit.py
  function _cfg (line 30) | def _cfg(url='', **kwargs):
  class Mlp (line 49) | class Mlp(nn.Module):
    method __init__ (line 50) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method forward (line 59) | def forward(self, x):
  class Attention (line 68) | class Attention(nn.Module):
    method __init__ (line 69) | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, at...
    method forward (line 81) | def forward(self, x):
  class Block (line 103) | class Block(nn.Module):
    method __init__ (line 105) | def __init__(self, dim, num_heads, layer_num, mlp_ratio=4., qkv_bias=F...
    method forward (line 136) | def forward(self, x, B, T, W):
  class PatchEmbed (line 216) | class PatchEmbed(nn.Module):
    method __init__ (line 220) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
    method forward (line 233) | def forward(self, x):
  class VisionTransformer (line 242) | class VisionTransformer(nn.Module):
    method __init__ (line 246) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
    method _init_weights (line 300) | def _init_weights(self, m):
    method no_weight_decay (line 310) | def no_weight_decay(self):
    method get_classifier (line 313) | def get_classifier(self):
    method reset_classifier (line 316) | def reset_classifier(self, num_classes, global_pool=''):
    method forward_features (line 321) | def forward_features(self, x, return_all_tokens=False):
    method forward (line 379) | def forward(self, x):
  function _conv_filter (line 385) | def _conv_filter(state_dict, patch_size=16):
  class vit_base_patch16_224 (line 397) | class vit_base_patch16_224(nn.Module):
    method __init__ (line 398) | def __init__(self, cfg, **kwargs):
    method forward (line 414) | def forward(self, x):
  class TimeSformer (line 419) | class TimeSformer(nn.Module):
    method __init__ (line 420) | def __init__(self, model_cfg, input_format='BGR', cross_attention_conf...
    method forward (line 471) | def forward(self, x):
    method forward_features (line 475) | def forward_features(self, x, return_all_tokens=True, pooling='tempora...
    method _get_pooled_features (line 505) | def _get_pooled_features(self, x):
    method load_state_dict (line 515) | def load_state_dict(self, pretrained_ckpt_path):

FILE: src/modeling/timesformer/vit_utils.py
  function _no_grad_trunc_normal_ (line 23) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
  function trunc_normal_ (line 56) | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
  function _ntuple (line 77) | def _ntuple(n):
  function get_padding (line 86) | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **...
  function get_padding_value (line 90) | def get_padding_value(padding, kernel_size, **kwargs):
  function get_same_padding (line 113) | def get_same_padding(x: int, k: int, s: int, d: int):
  function is_static_pad (line 118) | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, ...
  function pad_same (line 124) | def pad_same(x, k, s, d=(1, 1), value= 0):
  function adaptive_pool_feat_mult (line 131) | def adaptive_pool_feat_mult(pool_type='avg'):
  function drop_path (line 137) | def drop_path(x, drop_prob: float = 0., training: bool = False):
  class DropPath (line 154) | class DropPath(nn.Module):
    method __init__ (line 157) | def __init__(self, drop_prob=None):
    method forward (line 161) | def forward(self, x):

FILE: src/modeling/transformers.py
  function load_tf_weights_in_bert (line 64) | def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
  function mish (line 140) | def mish(x):
  class BertEmbeddings (line 151) | class BertEmbeddings(nn.Module):
    method __init__ (line 155) | def __init__(self, config):
    method forward (line 172) | def forward(self, input_ids=None, token_type_ids=None,
  class BertSelfAttention (line 202) | class BertSelfAttention(nn.Module):
    method __init__ (line 203) | def __init__(self, config):
    method transpose_for_scores (line 224) | def transpose_for_scores(self, x):
    method forward (line 230) | def forward(
  class BertSelfOutput (line 289) | class BertSelfOutput(nn.Module):
    method __init__ (line 290) | def __init__(self, config):
    method forward (line 297) | def forward(self, hidden_states, input_tensor):
  class BertAttention (line 304) | class BertAttention(nn.Module):
    method __init__ (line 305) | def __init__(self, config):
    method prune_heads (line 311) | def prune_heads(self, heads):
    method forward (line 337) | def forward(
  class BertIntermediate (line 354) | class BertIntermediate(nn.Module):
    method __init__ (line 355) | def __init__(self, config):
    method forward (line 363) | def forward(self, hidden_states):
  class BertOutput (line 369) | class BertOutput(nn.Module):
    method __init__ (line 370) | def __init__(self, config):
    method forward (line 377) | def forward(self, hidden_states, input_tensor):
  class BertLayer (line 384) | class BertLayer(nn.Module):
    method __init__ (line 385) | def __init__(self, config):
    method forward (line 394) | def forward(
  class BertEncoder (line 421) | class BertEncoder(nn.Module):
    method __init__ (line 422) | def __init__(self, config):
    method forward (line 429) | def forward(
  class BertPooler (line 464) | class BertPooler(nn.Module):
    method __init__ (line 465) | def __init__(self, config):
    method forward (line 470) | def forward(self, hidden_states):
  class BertPredictionHeadTransform (line 479) | class BertPredictionHeadTransform(nn.Module):
    method __init__ (line 480) | def __init__(self, config):
    method forward (line 490) | def forward(self, hidden_states):
  class BertLMPredictionHead (line 497) | class BertLMPredictionHead(nn.Module):
    method __init__ (line 498) | def __init__(self, config):
    method forward (line 512) | def forward(self, hidden_states):
  class BertOnlyMLMHead (line 518) | class BertOnlyMLMHead(nn.Module):
    method __init__ (line 519) | def __init__(self, config):
    method forward (line 523) | def forward(self, sequence_output):
  class BertOnlyNSPHead (line 528) | class BertOnlyNSPHead(nn.Module):
    method __init__ (line 529) | def __init__(self, config):
    method forward (line 533) | def forward(self, pooled_output):
  class BertPreTrainingHeads (line 538) | class BertPreTrainingHeads(nn.Module):
    method __init__ (line 539) | def __init__(self, config):
    method forward (line 544) | def forward(self, sequence_output, pooled_output):
  class BertPreTrainedModel (line 550) | class BertPreTrainedModel(PreTrainedModel):
    method _init_weights (line 559) | def _init_weights(self, module):

FILE: src/modeling/xbert.py
  function load_tf_weights_in_bert (line 92) | def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
  class BertEmbeddings (line 166) | class BertEmbeddings(nn.Module):
    method __init__ (line 169) | def __init__(self, config):
    method forward (line 186) | def forward(
  class BertSelfAttention (line 216) | class BertSelfAttention(nn.Module):
    method __init__ (line 217) | def __init__(self, config, is_cross_attention):
    method save_attn_gradients (line 245) | def save_attn_gradients(self, attn_gradients):
    method get_attn_gradients (line 248) | def get_attn_gradients(self):
    method save_attention_map (line 251) | def save_attention_map(self, attention_map):
    method get_attention_map (line 254) | def get_attention_map(self):
    method transpose_for_scores (line 258) | def transpose_for_scores(self, x):
    method forward (line 263) | def forward(
  class BertSelfOutput (line 349) | class BertSelfOutput(nn.Module):
    method __init__ (line 350) | def __init__(self, config):
    method forward (line 356) | def forward(self, hidden_states, input_tensor):
  class BertAttention (line 363) | class BertAttention(nn.Module):
    method __init__ (line 364) | def __init__(self, config, is_cross_attention=False):
    method prune_heads (line 370) | def prune_heads(self, heads):
    method forward (line 388) | def forward(
  class BertIntermediate (line 412) | class BertIntermediate(nn.Module):
    method __init__ (line 413) | def __init__(self, config):
    method forward (line 421) | def forward(self, hidden_states):
  class BertOutput (line 427) | class BertOutput(nn.Module):
    method __init__ (line 428) | def __init__(self, config):
    method forward (line 434) | def forward(self, hidden_states, input_tensor):
  class BertLayer (line 441) | class BertLayer(nn.Module):
    method __init__ (line 442) | def __init__(self, config, layer_num):
    method forward (line 457) | def forward(
    method feed_forward_chunk (line 516) | def feed_forward_chunk(self, attention_output):
  class BertEncoder (line 522) | class BertEncoder(nn.Module):
    method __init__ (line 523) | def __init__(self, config):
    method forward (line 528) | def forward(
  class BertPooler (line 633) | class BertPooler(nn.Module):
    method __init__ (line 634) | def __init__(self, config):
    method forward (line 639) | def forward(self, hidden_states):
  class BertPredictionHeadTransform (line 648) | class BertPredictionHeadTransform(nn.Module):
    method __init__ (line 649) | def __init__(self, config):
    method forward (line 658) | def forward(self, hidden_states):
  class BertLMPredictionHead (line 665) | class BertLMPredictionHead(nn.Module):
    method __init__ (line 666) | def __init__(self, config):
    method forward (line 679) | def forward(self, hidden_states):
  class BertOnlyMLMHead (line 685) | class BertOnlyMLMHead(nn.Module):
    method __init__ (line 686) | def __init__(self, config):
    method forward (line 690) | def forward(self, sequence_output):
  class BertOnlyNSPHead (line 695) | class BertOnlyNSPHead(nn.Module):
    method __init__ (line 696) | def __init__(self, config):
    method forward (line 700) | def forward(self, pooled_output):
  class BertPreTrainingHeads (line 705) | class BertPreTrainingHeads(nn.Module):
    method __init__ (line 706) | def __init__(self, config):
    method forward (line 711) | def forward(self, sequence_output, pooled_output):
  class BertPreTrainedModel (line 717) | class BertPreTrainedModel(PreTrainedModel):
    method _init_weights (line 728) | def _init_weights(self, module):
  class BertForPreTrainingOutput (line 742) | class BertForPreTrainingOutput(ModelOutput):
  class BertModel (line 832) | class BertModel(BertPreTrainedModel):
    method __init__ (line 842) | def __init__(self, config, add_pooling_layer=True):
    method get_input_embeddings (line 855) | def get_input_embeddings(self):
    method set_input_embeddings (line 858) | def set_input_embeddings(self, value):
    method _prune_heads (line 861) | def _prune_heads(self, heads_to_prune):
    method get_extended_attention_mask (line 878) | def get_extended_attention_mask(self, attention_mask: Tensor, input_sh...
    method forward (line 940) | def forward(
  class BertForPreTraining (line 1091) | class BertForPreTraining(BertPreTrainedModel):
    method __init__ (line 1092) | def __init__(self, config):
    method get_output_embeddings (line 1100) | def get_output_embeddings(self):
    method set_output_embeddings (line 1103) | def set_output_embeddings(self, new_embeddings):
    method forward (line 1108) | def forward(
  class BertLMHeadModel (line 1185) | class BertLMHeadModel(BertPreTrainedModel):
    method __init__ (line 1190) | def __init__(self, config):
    method get_output_embeddings (line 1198) | def get_output_embeddings(self):
    method set_output_embeddings (line 1201) | def set_output_embeddings(self, new_embeddings):
    method forward (line 1206) | def forward(
    method prepare_inputs_for_generation (line 1316) | def prepare_inputs_for_generation(self, input_ids, past=None, attentio...
    method _reorder_cache (line 1335) | def _reorder_cache(self, past, beam_idx):
  class BertForMaskedLM (line 1343) | class BertForMaskedLM(BertPreTrainedModel):
    method __init__ (line 1348) | def __init__(self, config):
    method get_output_embeddings (line 1356) | def get_output_embeddings(self):
    method set_output_embeddings (line 1359) | def set_output_embeddings(self, new_embeddings):
    method forward (line 1369) | def forward(
    method prepare_inputs_for_generation (line 1443) | def prepare_inputs_for_generation(self, input_ids, attention_mask=None...
  class BertForNextSentencePrediction (line 1462) | class BertForNextSentencePrediction(BertPreTrainedModel):
    method __init__ (line 1463) | def __init__(self, config):
    method forward (line 1473) | def forward(
  class BertForSequenceClassification (line 1556) | class BertForSequenceClassification(BertPreTrainedModel):
    method __init__ (line 1557) | def __init__(self, config):
    method forward (line 1574) | def forward(
  class BertForMultipleChoice (line 1641) | class BertForMultipleChoice(BertPreTrainedModel):
    method __init__ (line 1642) | def __init__(self, config):
    method forward (line 1658) | def forward(
  class BertForTokenClassification (line 1732) | class BertForTokenClassification(BertPreTrainedModel):
    method __init__ (line 1736) | def __init__(self, config):
    method forward (line 1753) | def forward(
  class BertForQuestionAnswering (line 1823) | class BertForQuestionAnswering(BertPreTrainedModel):
    method __init__ (line 1827) | def __init__(self, config):
    method forward (line 1843) | def forward(

FILE: src/optimization/adamw.py
  class AdamW (line 11) | class AdamW(Optimizer):
    method __init__ (line 22) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
    method step (line 40) | def step(self, closure=None):

FILE: src/optimization/sched.py
  function noam_schedule (line 8) | def noam_schedule(step, warmup_step=4000):
  function warmup_linear (line 14) | def warmup_linear(step, warmup_step, tot_step):
  function multi_step_schedule (line 20) | def multi_step_schedule(n_epoch, milestones, gamma=0.5):
  function get_lr_sched (line 28) | def get_lr_sched(global_step, decay, learning_rate,

FILE: src/optimization/utils.py
  function setup_e2e_optimizer (line 5) | def setup_e2e_optimizer(model, opts):

FILE: src/pretrain/run_pretrain_contrastive_only.py
  function mk_captions_pretrain_dataloader (line 37) | def mk_captions_pretrain_dataloader(dataset_name, anno_path, video_dir, ...
  function setup_dataloaders (line 104) | def setup_dataloaders(cfg, tokenizer):
  function setup_model (line 125) | def setup_model(cfg, device=None):
  function forward_step (line 185) | def forward_step(cfg, model, batch):
  function validate (line 195) | def validate(model, val_loader, cfg):
  function start_training (line 273) | def start_training():

FILE: src/pretrain/run_pretrain_sparse.py
  function mk_captions_pretrain_dataloader (line 37) | def mk_captions_pretrain_dataloader(dataset_name, anno_path, video_dir, ...
  function setup_dataloaders (line 105) | def setup_dataloaders(cfg, tokenizer):
  function setup_model (line 126) | def setup_model(cfg, device=None):
  function forward_step (line 184) | def forward_step(cfg, model, batch):
  function validate (line 194) | def validate(model, val_loader, cfg):
  function get_video_prompt_templates (line 326) | def get_video_prompt_templates():
  function get_image_prompt_templates (line 344) | def get_image_prompt_templates():
  function setup_text_prompts (line 365) | def setup_text_prompts(cfg, tokenizer):
  function start_training (line 404) | def start_training():

FILE: src/tasks/run_video_qa.py
  function mk_qa_dataloader (line 36) | def mk_qa_dataloader(task_type, anno_path, lmdb_dir, cfg, tokenizer,
  function setup_dataloaders (line 136) | def setup_dataloaders(cfg, tokenizer):
  function setup_model (line 156) | def setup_model(cfg, device=None):
  function forward_step (line 214) | def forward_step(model, batch, cfg):
  function validate (line 225) | def validate(model, val_loader, cfg, train_global_step, eval_score=True):
  function start_training (line 373) | def start_training(cfg):
  function start_inference (line 567) | def start_inference(cfg):

FILE: src/tasks/run_video_retrieval.py
  function mk_video_ret_datalist (line 40) | def mk_video_ret_datalist(raw_datalist, cfg):
  function mk_video_ret_dataloader (line 69) | def mk_video_ret_dataloader(anno_path, lmdb_dir, cfg, tokenizer, is_trai...
  function mk_video_ret_eval_dataloader (line 130) | def mk_video_ret_eval_dataloader(anno_path, lmdb_dir, cfg, tokenizer):
  function setup_dataloaders (line 176) | def setup_dataloaders(cfg, tokenizer):
  function setup_model (line 194) | def setup_model(cfg, device=None):
  function forward_step (line 244) | def forward_step(model, batch):
  function forward_inference_step (line 249) | def forward_inference_step(model, batch):
  function validate (line 254) | def validate(model, val_loader, eval_loader, cfg, train_global_step, eva...
  function start_training (line 302) | def start_training(cfg):
  function get_retrieval_metric_from_bool_matrix (line 515) | def get_retrieval_metric_from_bool_matrix(bool_matrix):
  function get_retrieval_scores (line 542) | def get_retrieval_scores(score_matrix, gt_row2col_id_mapping, row_idx2id...
  function eval_retrieval (line 559) | def eval_retrieval(vid_txt_score_dicts, gt_txt_id2vid_id, id2data):
  function inference_retrieval (line 633) | def inference_retrieval(model, val_loader, eval_file_path, cfg):
  function start_inference (line 741) | def start_inference(cfg):

FILE: src/utils/basic_utils.py
  function load_pickle (line 10) | def load_pickle(filename):
  function save_pickle (line 15) | def save_pickle(data, filename):
  function load_json (line 20) | def load_json(filename):
  function save_json (line 25) | def save_json(data, filename, save_pretty=False, sort_keys=False):
  function load_jsonl (line 33) | def load_jsonl(filename):
  function save_jsonl (line 38) | def save_jsonl(data, filename):
  function concat_json_list (line 44) | def concat_json_list(filepaths, save_path):
  function save_lines (line 51) | def save_lines(list_of_str, filepath):
  function read_lines (line 56) | def read_lines(filepath):
  function mkdirp (line 61) | def mkdirp(p):
  function flat_list_of_lists (line 66) | def flat_list_of_lists(l):
  function convert_to_seconds (line 71) | def convert_to_seconds(hms_time):
  function get_video_name_from_url (line 80) | def get_video_name_from_url(url):
  function merge_dicts (line 84) | def merge_dicts(list_dicts):
  function l2_normalize_np_array (line 91) | def l2_normalize_np_array(np_array, eps=1e-5):
  function make_zipfile (line 96) | def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None...
  class AverageMeter (line 127) | class AverageMeter(object):
    method __init__ (line 129) | def __init__(self):
    method reset (line 138) | def reset(self):
    method update (line 146) | def update(self, val, n=1):
  function dissect_by_lengths (line 155) | def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):
  function get_ratio_from_counter (line 174) | def get_ratio_from_counter(counter_obj, threshold=200):
  function get_rounded_percentage (line 181) | def get_rounded_percentage(float_number, n_floats=2):
  function read_dataframe (line 185) | def read_dataframe(pkl_path):
  function save_frames_grid (line 189) | def save_frames_grid(img_array, out_path):

FILE: src/utils/distributed.py
  function all_reduce_and_rescale_tensors (line 17) | def all_reduce_and_rescale_tensors(tensors, rescale_denom):
  function all_reduce_and_rescale_tensors_chunked (line 46) | def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom,
  function broadcast_tensors (line 99) | def broadcast_tensors(tensors, root_rank, buffer_size=10485760):
  function all_gather_list (line 149) | def all_gather_list(data, max_size=4096):
  function any_broadcast (line 181) | def any_broadcast(data, root_rank, max_size=4096):
  function allgather_object (line 206) | def allgather_object(obj, name=None):

FILE: src/utils/grad_ckpt.py
  function detach_variable (line 5) | def detach_variable(inputs):
  function check_backward_validity (line 18) | def check_backward_validity(inputs):
  class CheckpointFunction (line 23) | class CheckpointFunction(torch.autograd.Function):
    method forward (line 25) | def forward(ctx, run_function, length, *args):
    method backward (line 34) | def backward(ctx, *output_grads):

FILE: src/utils/load_save.py
  function save_training_meta (line 19) | def save_training_meta(args):
  class ModelSaver (line 45) | class ModelSaver(object):
    method __init__ (line 46) | def __init__(self, output_dir):
    method save (line 50) | def save(self, step, model, optimizer=None, prefix="model"):
  function load_state_dict_with_pos_embed_resizing (line 73) | def load_state_dict_with_pos_embed_resizing(model, loaded_state_dict_or_...
  function compare_dict_difference (line 138) | def compare_dict_difference(dict1, dict2, dict1_name="dict1",
  function _to_cuda (line 179) | def _to_cuda(state):
  function _to_cpu (line 197) | def _to_cpu(state):
  class TrainingRestorer (line 215) | class TrainingRestorer(object):
    method __init__ (line 217) | def __init__(self, opts, **ckpt_dict):
    method step (line 244) | def step(self):
    method save (line 257) | def save(self):
    method restore (line 267) | def restore(self):
  class E2E_TrainingRestorer (line 280) | class E2E_TrainingRestorer(object):
    method __init__ (line 281) | def __init__(self, opts, model, optimizer):
    method step (line 313) | def step(self):
    method save (line 326) | def save(self):
    method restore (line 336) | def restore(self, opts):

FILE: src/utils/logger.py
  function add_log_to_file (line 15) | def add_log_to_file(log_path):
  class TensorboardLogger (line 22) | class TensorboardLogger(object):
    method __init__ (line 23) | def __init__(self):
    method create (line 27) | def create(self, path):
    method noop (line 30) | def noop(self, *args, **kwargs):
    method step (line 33) | def step(self):
    method global_step (line 37) | def global_step(self):
    method global_step (line 41) | def global_step(self, step):
    method log_scalar_dict (line 44) | def log_scalar_dict(self, log_dict, prefix=''):
    method __getattr__ (line 58) | def __getattr__(self, name):
  class RunningMeter (line 67) | class RunningMeter(object):
    method __init__ (line 71) | def __init__(self, name, val=None, smooth=0.99):
    method __call__ (line 76) | def __call__(self, value):
    method __str__ (line 80) | def __str__(self):
    method val (line 84) | def val(self):
    method name (line 88) | def name(self):

FILE: src/utils/misc.py
  class NoOp (line 12) | class NoOp(object):
    method __getattr__ (line 14) | def __getattr__(self, name):
    method noop (line 17) | def noop(self, *args, **kwargs):
  function set_random_seed (line 21) | def set_random_seed(seed):
  function zero_none_grad (line 28) | def zero_none_grad(model):

Condensed preview: 62 files, each showing path, character count, and a content snippet (full structured content: 494K chars).
[
  {
    "path": ".gitignore",
    "chars": 659,
    "preview": ".vscode \n\n# script\ntmp_all/script/\n\n# Philly-realted #\npt/\n.ptconfig\n\n\n\n# Project-related   #\n*/*results*/\n*results*/\ntm"
  },
  {
    "path": "CODEOWNERS",
    "chars": 140,
    "preview": "# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 5156,
    "preview": "# Salesforce Open Source Community Code of Conduct\n\n## About the Code of Conduct\n\nEquality is a core value at Salesforce"
  },
  {
    "path": "CONTRIBUTING-ARCHIVED.md",
    "chars": 131,
    "preview": "# ARCHIVED\n\nThis project is `Archived` and is no longer actively maintained;\nWe are not accepting contributions or Pull "
  },
  {
    "path": "LICENSE",
    "chars": 1481,
    "preview": "Copyright (c) 2021, Salesforce.com, Inc.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with o"
  },
  {
    "path": "README.md",
    "chars": 8799,
    "preview": "# ALPRO (CVPR 22')\n\n## ALPRO is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS), a one-stop "
  },
  {
    "path": "SECURITY.md",
    "chars": 401,
    "preview": "## Security\n\nPlease report any security issue to [security@salesforce.com](mailto:security@salesforce.com)\nas soon as it"
  },
  {
    "path": "config_release/base_model.json",
    "chars": 491,
    "preview": "{\n    \"attention_probs_dropout_prob\": 0.1,\n    \"hidden_act\": \"gelu\",\n    \"hidden_dropout_prob\": 0.1,\n    \"hidden_size\": "
  },
  {
    "path": "config_release/didemo_ret.json",
    "chars": 1212,
    "preview": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"didemo\",\n      \"txt\": \"data/didemo_ret/txt/train.jsonl\",\n      \"img\": \"data"
  },
  {
    "path": "config_release/msrvtt_qa.json",
    "chars": 1370,
    "preview": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"msrvtt_qa\",\n      \"txt\": {\n        \"msrvtt_qa\": \"data/msrvtt_qa/txt/train.j"
  },
  {
    "path": "config_release/msrvtt_ret.json",
    "chars": 1150,
    "preview": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"msrvtt\",\n      \"txt\": \"data/msrvtt_ret/txt/train.jsonl\",\n      \"img\": \"data"
  },
  {
    "path": "config_release/msvd_qa.json",
    "chars": 1401,
    "preview": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"msvd_qa\",\n      \"txt\": {\n        \"msvd_qa\": \"data/msvd_qa/txt/train.jsonl\"\n"
  },
  {
    "path": "config_release/pretrain_alpro.json",
    "chars": 1592,
    "preview": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"webvid2m\",\n      \"ann\": \"data/webvid2m/txt/train.pkl\",\n      \"txt\": null,\n "
  },
  {
    "path": "config_release/pretrain_prompter.json",
    "chars": 1452,
    "preview": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"webvid2m\",\n      \"ann\": \"data/webvid2m/txt/train.pkl\",\n      \"txt\": null,\n "
  },
  {
    "path": "config_release/timesformer_divst_8x32_224_k600.json",
    "chars": 219,
    "preview": "{\n    \"cls\": \"TimeSformer\",\n    \"patch_size\": 16,\n    \"attn_drop_rate\": 0,\n    \"drop_rate\": 0,\n    \"drop_path_rate\": 0.1"
  },
  {
    "path": "config_release/timesformer_divst_8x32_224_k600_gc.json",
    "chars": 218,
    "preview": "{\n    \"cls\": \"TimeSformer\",\n    \"patch_size\": 16,\n    \"attn_drop_rate\": 0,\n    \"drop_rate\": 0,\n    \"drop_path_rate\": 0.1"
  },
  {
    "path": "env/install_pkg.sh",
    "chars": 606,
    "preview": "apt update\napt install lsof\n\n# horovod\nHOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \\\n    "
  },
  {
    "path": "env/requirements.txt",
    "chars": 203,
    "preview": "ipdb\njoblib\ncytoolz\nlz4==2.1.9\nlmdb==0.97\nmsgpack-numpy\nmsgpack\ntoolz\ntransformers==4.11.3\ntensorboard\ntqdm\neasydict\npyc"
  },
  {
    "path": "run_scripts/clear_cuda_cache.sh",
    "chars": 98,
    "preview": "for i in $(lsof /dev/nvidia* | grep python  | awk '{print $2}' | sort -u); do kill -9 $i; done\n\n\n\n"
  },
  {
    "path": "run_scripts/ft_didemo_ret.sh",
    "chars": 302,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/didemo_ret.json'\n\nhorovodrun "
  },
  {
    "path": "run_scripts/ft_msrvtt_qa.sh",
    "chars": 293,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/msrvtt_qa.json'\n\nhorovodrun -"
  },
  {
    "path": "run_scripts/ft_msrvtt_ret.sh",
    "chars": 301,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/msrvtt_ret.json'\n\nhorovodrun "
  },
  {
    "path": "run_scripts/ft_msvd_qa.sh",
    "chars": 289,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/msvd_qa.json'\n\nhorovodrun -np"
  },
  {
    "path": "run_scripts/inf_didemo_ret.sh",
    "chars": 536,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/didemo_ret.json'"
  },
  {
    "path": "run_scripts/inf_msrvtt_qa.sh",
    "chars": 525,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/msrvtt_qa.json'\n"
  },
  {
    "path": "run_scripts/inf_msrvtt_ret.sh",
    "chars": 536,
    "preview": "cd ..\n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/msrvtt_ret.json'\n"
  },
  {
    "path": "run_scripts/inf_msvd_qa.sh",
    "chars": 517,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/msvd_qa.json'\n\nT"
  },
  {
    "path": "run_scripts/pt_alpro.sh",
    "chars": 291,
    "preview": "cd ..\n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/pretrain_alpro.json'\n\nhorovodr"
  },
  {
    "path": "run_scripts/pt_prompter.sh",
    "chars": 310,
    "preview": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/pretrain_prompter.json'\n\nhoro"
  },
  {
    "path": "src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/configs/config.py",
    "chars": 16838,
    "preview": "\"\"\"\nModified from UNITER code\n\"\"\"\nimport os\nimport sys\nimport json\nimport argparse\n\nfrom easydict import EasyDict as edi"
  },
  {
    "path": "src/datasets/data_utils.py",
    "chars": 20814,
    "preview": "import torch\nimport random\nimport torchvision.transforms as transforms\nfrom torchvision.transforms.functional import pad"
  },
  {
    "path": "src/datasets/dataloader.py",
    "chars": 6275,
    "preview": "\"\"\"\nmodified from UNITER codebase\n\nA meta data loader for sampling from different datasets / training tasks\nA prefetch l"
  },
  {
    "path": "src/datasets/dataset_base.py",
    "chars": 7672,
    "preview": "from torch.utils.data import Dataset\nfrom PIL import Image\nimport io\nimport av\nimport torch\nimport numpy as np\nimport lm"
  },
  {
    "path": "src/datasets/dataset_pretrain_sparse.py",
    "chars": 11754,
    "preview": "import os\nimport json\nimport random\n\nimport torch\nimport spacy\nfrom torch.utils.data.dataloader import default_collate\nf"
  },
  {
    "path": "src/datasets/dataset_video_qa.py",
    "chars": 8497,
    "preview": "import os\nimport torch\nimport random\nimport numpy as np\nimport copy\nfrom torch.utils.data.dataloader import default_coll"
  },
  {
    "path": "src/datasets/dataset_video_retrieval.py",
    "chars": 9028,
    "preview": "import random\nimport copy\nimport os\nimport torch\nimport numpy as np\nfrom torch.utils.data.dataloader import default_coll"
  },
  {
    "path": "src/datasets/randaugment.py",
    "chars": 13293,
    "preview": "import cv2\nimport numpy as np\nimport torch\n\n\n## aug functions\ndef identity_func(img):\n    return img\n\n\ndef autocontrast_"
  },
  {
    "path": "src/modeling/alpro_models.py",
    "chars": 40403,
    "preview": "import copy\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom apex.normalization.fused_layer_norm im"
  },
  {
    "path": "src/modeling/timesformer/__init__.py",
    "chars": 241,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\n# from .build import MODEL_REGISTRY, build_mode"
  },
  {
    "path": "src/modeling/timesformer/conv2d_same.py",
    "chars": 3101,
    "preview": "# Copyright 2020 Ross Wightman\n# Conv2d w/ Same Padding\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional a"
  },
  {
    "path": "src/modeling/timesformer/features.py",
    "chars": 11749,
    "preview": "# Copyright 2020 Ross Wightman\n\nfrom collections import OrderedDict, defaultdict\nfrom copy import deepcopy\nfrom functool"
  },
  {
    "path": "src/modeling/timesformer/helpers.py",
    "chars": 16773,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Copyright 2020 Ross Wightman\n# Modified model "
  },
  {
    "path": "src/modeling/timesformer/linear.py",
    "chars": 479,
    "preview": "\"\"\" Linear layer (alternate definition)\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn as nn\n\ncla"
  },
  {
    "path": "src/modeling/timesformer/operators.py",
    "chars": 2655,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\n# \"\"\"Custom operators.\"\"\"\n\n# import torch\n# imp"
  },
  {
    "path": "src/modeling/timesformer/vit.py",
    "chars": 21819,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Copyright 2020 Ross Wightman\n# Modified Model "
  },
  {
    "path": "src/modeling/timesformer/vit_utils.py",
    "chars": 6459,
    "preview": "# Copyright 2020 Ross Wightman\n# Various utility functions\n\nimport torch\nimport torch.nn as nn\nfrom functools import par"
  },
  {
    "path": "src/modeling/transformers.py",
    "chars": 22364,
    "preview": "# coding=utf-8\r\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\r\n# Copyright (c) 201"
  },
  {
    "path": "src/modeling/xbert.py",
    "chars": 82185,
    "preview": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018,"
  },
  {
    "path": "src/optimization/adamw.py",
    "chars": 4403,
    "preview": "\"\"\"\nAdamW optimizer (weight decay fix)\ncopied from hugginface\n\"\"\"\nimport math\n\nimport torch\nfrom torch.optim import Opti"
  },
  {
    "path": "src/optimization/sched.py",
    "chars": 1515,
    "preview": "\"\"\"\noptimizer learning rate scheduling helpers\n\"\"\"\nfrom math import ceil\nfrom collections import Counter\n\n\ndef noam_sche"
  },
  {
    "path": "src/optimization/utils.py",
    "chars": 456,
    "preview": "from torch.optim import Adam, Adamax, SGD\nfrom src.optimization.adamw import AdamW\n\n\ndef setup_e2e_optimizer(model, opts"
  },
  {
    "path": "src/pretrain/run_pretrain_contrastive_only.py",
    "chars": 18469,
    "preview": "import os\n\nimport torch\nimport time\nimport random\nimport pprint\nimport math\nimport json\nfrom transformers import BertCon"
  },
  {
    "path": "src/pretrain/run_pretrain_sparse.py",
    "chars": 24838,
    "preview": "import os\n\nimport torch\nimport time\nimport random\nimport pprint\nimport math\nimport json\nfrom transformers import BertCon"
  },
  {
    "path": "src/tasks/run_video_qa.py",
    "chars": 26332,
    "preview": "import math\nimport os\nimport random\nimport time\nfrom collections import defaultdict\nfrom os.path import join\n\nimport hor"
  },
  {
    "path": "src/tasks/run_video_retrieval.py",
    "chars": 31252,
    "preview": "import json\nimport math\nimport os\nimport random\nimport time\nfrom collections import defaultdict\nfrom os.path import exis"
  },
  {
    "path": "src/utils/basic_utils.py",
    "chars": 6421,
    "preview": "import os\nimport ujson as json\nimport zipfile\nimport numpy as np\nimport pickle\n\nimport pandas as pd\n\n\ndef load_pickle(fi"
  },
  {
    "path": "src/utils/distributed.py",
    "chars": 7816,
    "preview": "\"\"\"\r\nCopyright (c) Microsoft Corporation.\r\nLicensed under the MIT license.\r\ndistributed API using Horovod\r\nModified from"
  },
  {
    "path": "src/utils/grad_ckpt.py",
    "chars": 1507,
    "preview": "import torch\nimport warnings\n\n\ndef detach_variable(inputs):\n    if isinstance(inputs, tuple):\n        out = []\n        f"
  },
  {
    "path": "src/utils/load_save.py",
    "chars": 14346,
    "preview": "\"\"\"\nsaving utilities\n\"\"\"\nimport json\nimport os\nfrom os.path import dirname, exists, join, realpath\nimport subprocess\nfro"
  },
  {
    "path": "src/utils/logger.py",
    "chars": 2315,
    "preview": "\"\"\"\nreferences: UNITER\n\"\"\"\n\nimport logging\nfrom tensorboardX import SummaryWriter\n\n\n_LOG_FMT = '%(asctime)s - %(levelnam"
  },
  {
    "path": "src/utils/misc.py",
    "chars": 583,
    "preview": "\"\"\"\nmodified from UNITER\n\"\"\"\nimport json\nimport random\nimport sys\n\nimport torch\nimport numpy as np\n\n\nclass NoOp(object):"
  }
]

About this extraction

This page contains the full source code of the salesforce/ALPRO GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 62 files (462.2 KB), approximately 112.4k tokens, and a symbol index with 509 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.
