[
  {
    "path": ".gitignore",
    "content": ".vscode \n\n# script\ntmp_all/script/\n\n# Philly-realted #\npt/\n.ptconfig\n\n\n\n# Project-related   #\n*/*results*/\n*results*/\ntmp*/\ncache/*\n*/cache*/\ntmp*.py\n*pickle\n\n# compiled files #\n*.pyc\n\n# Packages #\n############\n# it's better to unpack these files and commit the raw source\n# git has its own built in compression methods\n*.7z\n*.dmg\n*.gz\n*.iso\n*.jar\n*.rar\n*.tar\n*.zip\n\n# Logs and databases #\n######################\n*.log\n*.sql\n*.sqlite\n.ipynb_checkpoints/\n*.swp\n*.vscode/\n*.idea/\n\n# OS generated files #\n######################\n.DS_Store\n.DS_Store?\n._*\n.Spotlight-V100\n.Trashes\nehthumbs.db\nThumbs.db\n\n# project-specific\nimg\ntxt\next\ndata\noutput\nsrc/configs_local\n"
  },
  {
    "path": "CODEOWNERS",
    "content": "# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.\n#ECCN:Open Source\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "# Salesforce Open Source Community Code of Conduct\n\n## About the Code of Conduct\n\nEquality is a core value at Salesforce. We believe a diverse and inclusive\ncommunity fosters innovation and creativity, and are committed to building a\nculture where everyone feels included.\n\nSalesforce open-source projects are committed to providing a friendly, safe, and\nwelcoming environment for all, regardless of gender identity and expression,\nsexual orientation, disability, physical appearance, body size, ethnicity, nationality, \nrace, age, religion, level of experience, education, socioeconomic status, or \nother similar personal characteristics.\n\nThe goal of this code of conduct is to specify a baseline standard of behavior so\nthat people with different social values and communication styles can work\ntogether effectively, productively, and respectfully in our open source community.\nIt also establishes a mechanism for reporting issues and resolving conflicts.\n\nAll questions and reports of abusive, harassing, or otherwise unacceptable behavior\nin a Salesforce open-source project may be reported by contacting the Salesforce\nOpen Source Conduct Committee at ossconduct@salesforce.com.\n\n## Our Pledge\n\nIn the interest of fostering an open and welcoming environment, we as\ncontributors and maintainers pledge to making participation in our project and\nour community a harassment-free experience for everyone, regardless of gender \nidentity and expression, sexual orientation, disability, physical appearance, \nbody size, ethnicity, nationality, race, age, religion, level of experience, education, \nsocioeconomic status, or other similar personal characteristics.\n\n## Our Standards\n\nExamples of behavior that contributes to creating a positive environment\ninclude:\n\n* Using welcoming and inclusive language\n* Being respectful of differing viewpoints and experiences\n* Gracefully accepting constructive criticism\n* Focusing on what is best for the community\n* Showing empathy toward other community members\n\nExamples of unacceptable behavior by participants include:\n\n* The use of sexualized language or imagery and unwelcome sexual attention or\nadvances\n* Personal attacks, insulting/derogatory comments, or trolling\n* Public or private harassment\n* Publishing, or threatening to publish, others' private information—such as\na physical or electronic address—without explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\nprofessional setting\n* Advocating for or encouraging any of the above behaviors\n\n## Our Responsibilities\n\nProject maintainers are responsible for clarifying the standards of acceptable\nbehavior and are expected to take appropriate and fair corrective action in\nresponse to any instances of unacceptable behavior.\n\nProject maintainers have the right and responsibility to remove, edit, or\nreject comments, commits, code, wiki edits, issues, and other contributions\nthat are not aligned with this Code of Conduct, or to ban temporarily or\npermanently any contributor for other behaviors that they deem inappropriate,\nthreatening, offensive, or harmful.\n\n## Scope\n\nThis Code of Conduct applies both within project spaces and in public spaces\nwhen an individual is representing the project or its community. 
Examples of\nrepresenting a project or community include using an official project email\naddress, posting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event. Representation of a project may be\nfurther defined and clarified by project maintainers.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported by contacting the Salesforce Open Source Conduct Committee \nat ossconduct@salesforce.com. All complaints will be reviewed and investigated \nand will result in a response that is deemed necessary and appropriate to the \ncircumstances. The committee is obligated to maintain confidentiality with \nregard to the reporter of an incident. Further details of specific enforcement \npolicies may be posted separately.\n\nProject maintainers who do not follow or enforce the Code of Conduct in good\nfaith may face temporary or permanent repercussions as determined by other\nmembers of the project's leadership and the Salesforce Open Source Conduct \nCommittee.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],\nversion 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. \nIt includes adaptations and additions from [Go Community Code of Conduct][golang-coc], \n[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].\n\nThis Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].\n\n[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)\n[golang-coc]: https://golang.org/conduct\n[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md\n[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/\n[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/\n\n"
  },
  {
    "path": "CONTRIBUTING-ARCHIVED.md",
    "content": "# ARCHIVED\n\nThis project is `Archived` and is no longer actively maintained;\nWe are not accepting contributions or Pull Requests.\n\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2021, Salesforce.com, Inc.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.\n\n* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."
  },
  {
    "path": "README.md",
    "content": "# ALPRO (CVPR 22')\n\n## ALPRO is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS), a one-stop library for language-vision intelligence!\n\n## Align and Prompt: Video-and-Language Pre-training with Entity Prompts [[Paper](https://arxiv.org/abs/2112.09583)]\n\n[Dongxu Li](https://www.linkedin.com/in/dongxu-li-a8a035110/), [Junnan Li](https://sites.google.com/site/junnanlics), [Hongdong Li](http://users.cecs.anu.edu.au/~hongdong/), [Juan Carlos Niebles](http://www.niebles.net/), [Steven C.H. Hoi](https://sites.google.com/view/stevenhoi/home)\n\n<img src=\"pics/teaser.jpg\" width=\"500\">\n\nOfficial PyTorch code for ALPRO. This repository supports pre-training as well as finetuning on \n- Text-Video Retrieval on MSRVTT and DiDeMo.\n- Video Question Anwsering on MSRVTT and MSVD.\n\n## Requirements\nOur implementation is tested on Ubuntu 20.04.1 with NVIDIA A100 GPUs. Supports for other platforms and hardwares are possible with no warrant. To install the required packages:\n\n```bash\ncd env && bash install_pkg.sh\n```\n\n## Data Preparation \n1. Download Annotations and Pre-trained Checkpoints\n    - [Text annotations](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/data.zip)\n    - [Checkpoints of pre-trained model and finetuned model](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/output.zip)\n    - [Externel resources](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/ext.zip)\n    - unzip `data.zip`, `output.zip`, `ext.zip` under `ALPRO/`.\n \n2. Download raw videos of downstream datasets.\n   - MSRVTT:\n     - download train_val_videos.zip and test_videos.zip from e.g. [here](https://www.mediafire.com/folder/h14iarbs62e7p/shared).\n     - check md5sum:\n\n        ```bash\n        51f2394d279cf84f1642defd9a651e6f  train_val_videos.zip\n        0af68454cec9d586e92805739f3911d0  test_videos.zip\n        ```\n     - unzip all the videos into `data/msrvtt_ret/videos` (10k in total).\n     - create the following soft link:\n\n        ```bash\n        ln -s data/msrvtt_ret/videos data/msrvtt_qa/videos```\n    - MSVD:\n      - download from official release:\n  \n        ```bash\n        wget -nc https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar\n        ```\n      - check md5sum:\n      \n        ```bash\n        9bdb20fcf14d59524a6febca9f6a8d89  YouTubeClips.tar\n        ```\n      - unzip all the videos to `data/msvd_qa/videos` (1,970 videos in total).\n        \n        ```bash\n        mkdir data/msvd_qa/videos/ \n        tar xvf YouTubeClips.tar -C data/msvd_qa/videos --strip-components=1\n        ```\n    - DiDeMo:\n       - Following [instructions](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md) and download from the official release [here](https://drive.google.com/drive/u/1/folders/1_oyJ5rQiZboipbMl6tkhY8v0s9zDkvJc);\n       - unzip all the videos into `data/didemo_ret/videos`.\n       - Note there might be a couple videos missing. See [here](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md#getting-the-videos) to download. However, as they account for a small portion of training set, you may feel safe to ignore.\n       - Convert all the DiDeMo videos into `*.mp4` format using e.g. [`ffmpeg`](https://askubuntu.com/questions/396883/how-to-simply-convert-video-files-i-e-mkv-to-mp4).\n       - We obtained 10,463 videos following these steps (with one video `77807177@N00_5753455690_1e04ccb364` missing).\n\n\n\n  3. 
The directory is expected to be in the structure below:\n      ```bash\n      .\n      |-config_release  # configuration files\n      |-data  # text annotations and raw videos\n      |---didemo_ret\n      |-----txt\n      |-----videos\n      |---msrvtt_qa/...\n      |---msrvtt_ret/...\n      |---msvd_qa/...\n      |-env  # scripts to install packages\n      |-ext  # external resources, e.g. bert tokenizer\n      |-output  # checkpoints for pre-trained/finetuned models\n      |---downstreams\n      |-----didemo_ret\n      |-------public\n      |---------ckpt # official finetuned checkpoints\n      |---------log # inference log\n      |---------results_test\n      |-----------step_best_1_mean\n      |-----msrvtt_qa/...\n      |-----msrvtt_ret/...\n      |-----msvd_qa/...\n      |-run_scripts  # bash scripts to launch experiments\n      |-src  # source code\n      ```\n\n## Inference with Official Checkpoints\n\n  ```bash\n  cd run_scripts\n  bash inf_msrvtt_ret.sh\n  # {'text2video': {'r1': 33.9, 'r5': 60.7, 'r10': 73.2, 'medianR': 3.0, 'meanR': 27.404}}\n  bash inf_didemo_ret.sh\n  # {'text2video': {'r1': 35.9, 'r5': 67.5, 'r10': 78.8, 'medianR': 3.0, 'meanR': 19.125}}\n  bash inf_msrvtt_qa.sh\n  # {'ratios': {'what_ratio': [68.48, 49872], 'who_ratio': [27.99, 20385], 'how_ratio': [2.25, 1640], 'where_ratio': [0.34, 250], 'when_ratio': [0.93, 677]}, 'overall_acc': 42.12, 'what_acc': 36.05, 'who_acc': 52.24, 'how_acc': 85.67, 'where_acc': 42.8, 'when_acc': 78.88}\n  bash inf_msvd_qa.sh\n  # {'ratios': {'what_ratio': [61.93, 8150], 'who_ratio': [34.6, 4554], 'how_ratio': [2.81, 370], 'where_ratio': [0.21, 28], 'when_ratio': [0.44, 58]}, 'overall_acc': 45.91, 'what_acc': 37.02, 'who_acc': 58.59, 'how_acc': 81.62, 'where_acc': 46.43, 'when_acc': 72.41}\n  ```\n\n\n## Downstream Task Finetuning\n  - To finetune on downstream tasks with the pre-trained checkpoint `output/pretrain/alpro_pretrained_ckpt.pt`:\n\n    ```bash\n    cd run_scripts\n    bash ft_msrvtt_ret.sh\n    bash ft_didemo_ret.sh\n    bash ft_msrvtt_qa.sh\n    bash ft_msvd_qa.sh\n    ```\n  \n    For example, with MSRVTT retrieval:\n    ```bash\n    cd ALPRO/\n\n    export PYTHONPATH=\"$PYTHONPATH:$PWD\"\n    echo $PYTHONPATH\n\n    CONFIG_PATH='config_release/msrvtt_ret.json'\n\n    # change -np to the number of GPUs available.\n    horovodrun -np 8 python src/tasks/run_video_retrieval.py \\\n        --config $CONFIG_PATH \\\n        --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S')  # change to your local path to store finetuning ckpts and logs\n    ```\n - Run inference with locally-finetuned checkpoints.\n   ```bash\n    cd ALPRO/\n\n    export PYTHONPATH=\"$PYTHONPATH:$PWD\"\n    echo $PYTHONPATH\n\n    STEP='best'\n\n    CONFIG_PATH='config_release/msrvtt_ret.json'\n    OUTPUT_DIR='[INPUT_YOUR_OUTPUT_PATH_HERE]'\n\n    TXT_DB='data/msrvtt_ret/txt/test.jsonl'\n    IMG_DB='data/msrvtt_ret/videos'\n\n    horovodrun -np 8 python src/tasks/run_video_retrieval.py \\\n        --do_inference 1 \\\n        --inference_split test \\\n        --inference_model_step $STEP \\\n        --inference_txt_db $TXT_DB \\\n        --inference_img_db $IMG_DB \\\n        --inference_batch_size 64 \\\n        --output_dir $OUTPUT_DIR \\\n        --config $CONFIG_PATH\n   ```\n   - `OUTPUT_DIR` is the path given to the `--output_dir` option in the finetuning script.\n   - `$STEP` is a string, which tells the script to use the checkpoint `$OUTPUT_DIR/ckpt/model_step_$STEP.pt` for inference.\n
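   - For instance, to see which checkpoints are available and pick one (the numeric step below is hypothetical):\n\n     ```bash\n     ls $OUTPUT_DIR/ckpt/\n     # e.g. model_step_2000.pt  model_step_best.pt\n     STEP='2000'  # or STEP='best'\n     ```\n\n\n## Pretraining\n1. 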
Download [WebVid2M](https://github.com/m-bain/frozen-in-time) and [CC-3M](https://github.com/igorbrigadir/DownloadConceptualCaptions).\n  \n    - Put WebVid2M videos under `data/webvid2m`;\n    - 💡 we downsample webvid2m videos to 10% of the original FPS to speed up video loading;\n    - change `data/cc3m/txt/cc3m.json` with local image paths.\n\n2. Train the prompter:\n    ```bash\n    cd run_scripts && bash pt_prompter.sh\n    ```\n\n3. Train the video-language model:\n    ```bash\n    cd run_scripts && bash pt_alpro.sh\n    ```\n    If you would like to use custom prompter weights, please change `teacher_weights_path` in `config_release/pretrain_alpro.json`.\n4. To finetune with your own pre-trained checkpoint, please change `e2e_weights_path` in the finetuning config files, e.g. `config_release/msrvtt_ret.json`.\n\n\n## Citation\n\nIf you find ALPRO useful for your research, please consider citing:\n```bibtex\n  @inproceedings{li2021align,\n    title={Align and Prompt: Video-and-Language Pre-training with Entity Prompts},\n    author={Dongxu Li and Junnan Li and Hongdong Li and Juan Carlos Niebles and Steven C.H. Hoi},\n    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n    year={2022}\n  }\n```\n\n## Acknowledgement\nWe thank members at Salesforce Research for their helpful discussions.\n\nThe implementation of ALPRO relies on resources from [ClipBERT](https://github.com/jayleicn/ClipBERT),\n[transformers](https://github.com/huggingface/transformers), and\n[TimeSformer](https://github.com/facebookresearch/TimeSformer/tree/main/timesformer/models).\nThe code is implemented using [PyTorch](https://github.com/pytorch/pytorch),\nwith multi-GPU support from [Horovod](https://github.com/horovod/horovod) and [gradient-checkpoint](https://github.com/csrhddlam/pytorch-checkpoint). We thank the original authors for open-sourcing their work and encourage ALPRO users to cite their works when applicable.\n\n"
  },
  {
    "path": "SECURITY.md",
    "content": "## Security\n\nPlease report any security issue to [security@salesforce.com](mailto:security@salesforce.com)\nas soon as it is discovered. This library limits its runtime dependencies in\norder to reduce the total cost of ownership as much as can be, but all consumers\nshould remain vigilant and have their security stakeholders review all third-party\nproducts (3PP) like this one and their dependencies.\n"
  },
  {
    "path": "config_release/base_model.json",
    "content": "{\n    \"attention_probs_dropout_prob\": 0.1,\n    \"hidden_act\": \"gelu\",\n    \"hidden_dropout_prob\": 0.1,\n    \"hidden_size\": 768,\n    \"initializer_range\": 0.02,\n    \"intermediate_size\": 3072,\n    \"layer_norm_eps\": 1e-12,\n    \"max_position_embeddings\": 512,\n    \"model_type\": \"bert\",\n    \"num_attention_heads\": 12,\n    \"num_hidden_layers\": 12,\n    \"pad_token_id\": 0,\n    \"type_vocab_size\": 2,\n    \"vocab_size\": 30522,\n    \"fusion_layer\": 6,\n    \"encoder_width\": 768,\n    \"itc_token_type\": \"cls\"\n}\n"
  },
  {
    "path": "config_release/didemo_ret.json",
    "content": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"didemo\",\n      \"txt\": \"data/didemo_ret/txt/train.jsonl\",\n      \"img\": \"data/didemo_ret/videos\"\n    }\n  ],\n  \"val_datasets\": [\n    {\n      \"name\": \"didemo_retrieval\",\n      \"txt\": \"data/didemo_ret/txt/val.jsonl\",\n      \"img\": \"data/didemo_ret/videos\"\n    }\n  ],\n  \"max_txt_len\": 50,\n  \"crop_img_size\": 224,\n  \"resize_size\": 256,\n  \"img_pixel_mean\": [0.48145466, 0.4578275, 0.40821073], \n  \"img_pixel_std\": [0.26862954, 0.26130258, 0.27577711],\n  \"img_input_format\": \"RGB\",\n  \"num_frm\": 8,\n  \"train_n_clips\": 1,\n  \"max_n_example_per_group\": 1,\n  \"model_config\": \"config_release/base_model.json\",\n  \"tokenizer_dir\": \"ext/bert-base-uncased/\",\n  \"visual_model_cfg\": \"config_release/timesformer_divst_8x32_224_k600.json\",\n  \"e2e_weights_path\": \"output/pretrain/alpro_pretrained_ckpt.pt\",\n  \"bert_weights_path\": null,\n  \"train_batch_size\": 12,\n  \"val_batch_size\": 12,\n  \"gradient_accumulation_steps\": 1,\n  \"num_train_epochs\": 10,\n  \"min_valid_steps\": 20,\n  \"num_valid\": 20,\n  \"learning_rate\": 4e-5,\n  \"weight_decay\": 1e-3,\n  \"decay\": \"linear\",\n  \"optim\": \"adamw\",\n  \"betas\": [0.9, 0.98],\n  \"dropout\": 0.1,\n  \"grad_norm\": 20.0,\n  \"seed\":42,\n  \"fp16\": 0,\n  \"num_workers\": 4\n}\n"
  },
  {
    "path": "config_release/msrvtt_qa.json",
    "content": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"msrvtt_qa\",\n      \"txt\": {\n        \"msrvtt_qa\": \"data/msrvtt_qa/txt/train.jsonl\"\n      },\n      \"img\": \"data/msrvtt_qa/videos\"\n    }\n  ],\n  \"val_datasets\": [\n    {\n      \"name\": \"msrvtt_qa\",\n      \"txt\": {\n        \"msrvtt_qa\": \"data/msrvtt_qa/txt/val.jsonl\"\n      },\n      \"img\": \"data/msrvtt_qa/videos\"\n    }\n  ],\n  \"ans2label_path\": \"data/msrvtt_qa/txt/train_ans2label.json\",\n  \"max_txt_len\": 40,\n  \"crop_img_size\": 224,\n  \"resize_size\": 256,\n  \"img_pixel_mean\": [0.48145466, 0.4578275, 0.40821073], \n  \"img_pixel_std\": [0.26862954, 0.26130258, 0.27577711],\n  \"img_input_format\": \"RGB\",\n  \"train_n_clips\": 1,\n  \"model_config\": \"config_release/base_model.json\",\n  \"tokenizer_dir\": \"ext/bert-base-uncased/\",\n  \"visual_model_cfg\": \"config_release/timesformer_divst_8x32_224_k600_gc.json\",\n  \"e2e_weights_path\": \"output/pretrain/alpro_pretrained_ckpt.pt\",\n  \"num_frm\": 16,\n  \"train_batch_size\": 12,\n  \"val_batch_size\": 12,\n  \"gradient_accumulation_steps\": 2,\n  \"num_train_epochs\": 10,\n  \"min_valid_steps\": 50,\n  \"num_valid\": 50,\n  \"learning_rate\": 5e-5,\n  \"weight_decay\": 1e-3,\n  \"decay\": \"linear\",\n  \"optim\": \"adamw\",\n  \"betas\": [0.9, 0.98],\n  \"dropout\": 0.1,\n  \"grad_norm\": 5.0,\n  \"cnn_lr_decay\": \"linear\",\n  \"seed\":42,\n  \"fp16\": 0,\n  \"classifier\": \"mlp\",\n  \"cls_hidden_scale\": 2,\n  \"task\": \"msrvtt_qa\",\n  \"num_workers\": 4\n}\n"
  },
  {
    "path": "config_release/msrvtt_ret.json",
    "content": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"msrvtt\",\n      \"txt\": \"data/msrvtt_ret/txt/train.jsonl\",\n      \"img\": \"data/msrvtt_ret/videos\"\n    }\n  ],\n  \"val_datasets\": [\n    {\n      \"name\": \"msrvtt_retrieval\",\n      \"txt\": \"data/msrvtt_ret/txt/val.jsonl\",\n      \"img\": \"data/msrvtt_ret/videos\"\n    }\n  ],\n  \"max_txt_len\": 40,\n  \"crop_img_size\": 224,\n  \"resize_size\": 256,\n  \"img_pixel_mean\": [0.48145466, 0.4578275, 0.40821073], \n  \"img_pixel_std\": [0.26862954, 0.26130258, 0.27577711],\n  \"img_input_format\": \"RGB\",\n  \"train_n_clips\": 1,\n  \"model_config\": \"config_release/base_model.json\",\n  \"tokenizer_dir\": \"ext/bert-base-uncased/\",\n  \"visual_model_cfg\": \"config_release/timesformer_divst_8x32_224_k600.json\",\n  \"e2e_weights_path\": \"output/pretrain/alpro_pretrained_ckpt.pt\",\n  \"num_frm\": 8,\n  \"train_batch_size\": 8,\n  \"val_batch_size\": 8,\n  \"gradient_accumulation_steps\": 1,\n  \"num_train_epochs\": 5,\n  \"min_valid_steps\": 100,\n  \"num_valid\": 20,\n  \"learning_rate\": 2.5e-5,\n  \"weight_decay\": 1e-3,\n  \"decay\": \"linear\",\n  \"optim\": \"adamw\",\n  \"betas\": [0.9, 0.98],\n  \"dropout\": 0.1,\n  \"grad_norm\": 5.0,\n  \"seed\":42,\n  \"fp16\": 0,\n  \"num_workers\": 4\n}\n"
  },
  {
    "path": "config_release/msvd_qa.json",
    "content": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"msvd_qa\",\n      \"txt\": {\n        \"msvd_qa\": \"data/msvd_qa/txt/train.jsonl\"\n      },\n      \"img\": \"data/msvd_qa/videos\"\n    }\n  ],\n  \"val_datasets\": [\n    {\n      \"name\": \"msvd_qa\",\n      \"txt\": {\n        \"msvd_qa\": \"data/msvd_qa/txt/val.jsonl\"\n      },\n      \"img\": \"data/msvd_qa/videos\"\n    }\n  ],\n  \"ans2label_path\": \"data/msvd_qa/txt/train_ans2label.json\",\n  \"num_labels\": 2423,\n  \"max_txt_len\": 40,\n  \"crop_img_size\": 224,\n  \"resize_size\": 256,\n  \"img_pixel_mean\": [0.48145466, 0.4578275, 0.40821073], \n  \"img_pixel_std\": [0.26862954, 0.26130258, 0.27577711],\n  \"img_input_format\": \"RGB\",\n  \"train_n_clips\": 1,\n  \"num_frm\": 16,\n  \"model_config\": \"config_release/base_model.json\",\n  \"tokenizer_dir\": \"ext/bert-base-uncased/\",\n  \"visual_model_cfg\": \"config_release/timesformer_divst_8x32_224_k600_gc.json\",\n  \"e2e_weights_path\": \"output/pretrain/alpro_pretrained_ckpt.pt\",\n  \"train_batch_size\": 12,\n  \"val_batch_size\": 12,\n  \"gradient_accumulation_steps\": 2,\n  \"num_train_epochs\": 15,\n  \"min_valid_steps\": 50,\n  \"num_valid\": 30,\n  \"learning_rate\": 5e-5,\n  \"weight_decay\": 1e-3,\n  \"decay\": \"linear\",\n  \"optim\": \"adamw\",\n  \"betas\": [0.9, 0.98],\n  \"dropout\": 0.1,\n  \"grad_norm\": 20.0,\n  \"cnn_lr_decay\": \"linear\",\n  \"seed\":42,\n  \"fp16\": 0,\n  \"save_steps_ratio\": 0.05,\n  \"classifier\": \"mlp\",\n  \"cls_hidden_scale\": 2,\n  \"task\": \"msvd_qa\",\n  \"num_workers\": 4\n}\n"
  },
  {
    "path": "config_release/pretrain_alpro.json",
    "content": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"webvid2m\",\n      \"ann\": \"data/webvid2m/txt/train.pkl\",\n      \"txt\": null,\n      \"img\": \"data/webvid2m/videos\"\n    },\n    {\n      \"name\": \"cc3m\",\n      \"ann\": \"data/cc3m/txt/cc3m.json\",\n      \"txt\": null,\n      \"img\": null \n    }\n  ],\n  \"val_datasets\": [\n    {\n      \"name\": \"webvid2m\",\n      \"ann\": \"data/webvid2m/txt/val.pkl\",\n      \"txt\": null,\n      \"img\": \"data/webvid2m/videos\"\n    }\n  ],\n  \"img_pixel_mean\": [0.48145466, 0.4578275, 0.40821073], \n  \"img_pixel_std\": [0.26862954, 0.26130258, 0.27577711],\n  \"img_input_format\": \"RGB\",\n  \"model_type\": \"pretrain\",\n  \"model_config\": \"config_release/base_model.json\",\n  \"visual_model_cfg\": \"config_release/timesformer_divst_8x32_224_k600.json\",\n  \"visual_weights_path\": \"vit_base_patch16_224\",\n  \"teacher_weights_path\": \"output/pretrain/prompter_pretrained.pt\",\n  \"entity_file_path\": \"data/unigrams.txt\",\n  \"tokenizer_dir\": \"ext/bert-base-uncased/\",\n  \"max_txt_len\": 30,\n  \"crop_img_size\": 224,\n  \"resize_size\": 256,\n  \"train_batch_size\": 16,\n  \"val_batch_size\": 16,\n  \"gradient_accumulation_steps\": 1,\n  \"num_train_epochs\": 10,\n  \"min_valid_steps\": 10,\n  \"num_valid\": 10,\n  \"learning_rate\": 1e-4,\n  \"decay\": \"linear\",\n  \"optim\": \"adamw\",\n  \"betas\": [0.9, 0.98],\n  \"dropout\": 0.1,\n  \"weight_decay\": 1e-3,\n  \"grad_norm\": 20.0,\n  \"seed\":42,\n  \"fp16\": 0,\n  \"use_itm\": 1,\n  \"use_mlm\": 1,\n  \"use_itc\": 1,\n  \"use_mpm\": 1,\n  \"n_workers\": 4,\n  \"save_steps_ratio\": 0.01,\n  \"frm_sampling_strategy\": \"headtail\",\n  \"num_frm\": 4,\n  \"fps\": 0.5,\n  \"debug\": false,\n  \"warmup_ratio\": 0.05,\n  \"log_interval\": 100\n}\n"
  },
  {
    "path": "config_release/pretrain_prompter.json",
    "content": "{\n  \"train_datasets\": [\n    {\n      \"name\": \"webvid2m\",\n      \"ann\": \"data/webvid2m/txt/train.pkl\",\n      \"txt\": null,\n      \"img\": \"data/webvid2m/videos\"\n    },\n    {\n      \"name\": \"cc3m\",\n      \"ann\": \"data/cc3m/txt/cc3m.json\",\n      \"txt\": null,\n      \"img\": null \n    }\n  ],\n  \"val_datasets\": [\n    {\n      \"name\": \"webvid2m\",\n      \"ann\": \"data/webvid2m/txt/val.pkl\",\n      \"txt\": null,\n      \"img\": \"data/webvid2m/videos\"\n    }\n  ],\n  \"img_pixel_mean\": [0.48145466, 0.4578275, 0.40821073], \n  \"img_pixel_std\": [0.26862954, 0.26130258, 0.27577711],\n  \"img_input_format\": \"RGB\",\n  \"model_type\": \"pretrain\",\n  \"model_config\": \"config_release/base_model.json\",\n  \"visual_model_cfg\": \"config_release/timesformer_divst_8x32_224_k600.json\",\n  \"visual_weights_path\": \"vit_base_patch16_224\",\n  \"tokenizer_dir\": \"ext/bert-base-uncased/\",\n  \"max_txt_len\": 30,\n  \"crop_img_size\": 224,\n  \"resize_size\": 256,\n  \"train_batch_size\": 16,\n  \"val_batch_size\": 16,\n  \"gradient_accumulation_steps\": 2,\n  \"num_train_epochs\": 10,\n  \"min_valid_steps\": 100,\n  \"num_valid\": 10,\n  \"learning_rate\": 1e-4,\n  \"decay\": \"linear\",\n  \"optim\": \"adamw\",\n  \"betas\": [0.9, 0.98],\n  \"dropout\": 0.1,\n  \"weight_decay\": 1e-3,\n  \"grad_norm\": 20.0,\n  \"seed\":42,\n  \"fp16\": 0,\n  \"use_itm\": 0,\n  \"use_mlm\": 0,\n  \"use_itc\": 1,\n  \"n_workers\": 4,\n  \"save_steps_ratio\": 0.05,\n  \"frm_sampling_strategy\": \"headtail\",\n  \"num_frm\": 4,\n  \"debug\": false,\n  \"warmup_ratio\": 0.05,\n  \"log_interval\": 100\n}\n"
  },
  {
    "path": "config_release/timesformer_divst_8x32_224_k600.json",
    "content": "{\n    \"cls\": \"TimeSformer\",\n    \"patch_size\": 16,\n    \"attn_drop_rate\": 0,\n    \"drop_rate\": 0,\n    \"drop_path_rate\": 0.1,\n    \"maxpool_kernel_size\": 2,\n    \"use_maxpooling\": false,\n    \"gradient_checkpointing\": false\n}\n"
  },
  {
    "path": "config_release/timesformer_divst_8x32_224_k600_gc.json",
    "content": "{\n    \"cls\": \"TimeSformer\",\n    \"patch_size\": 16,\n    \"attn_drop_rate\": 0,\n    \"drop_rate\": 0,\n    \"drop_path_rate\": 0.1,\n    \"maxpool_kernel_size\": 2,\n    \"use_maxpooling\": false,\n    \"gradient_checkpointing\": true\n}\n"
  },
  {
    "path": "env/install_pkg.sh",
    "content": "apt update\napt install lsof\n\n# horovod\nHOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \\\n    pip install --no-cache-dir horovod==0.19.4 &&\\\n    ldconfig\n\n# use the faster pillow-simd instead of the original pillow\n# https://github.com/uploadcare/pillow-simd\npip uninstall pillow && \\\nCC=\"cc -mavx2\" pip install -U --force-reinstall pillow-simd\n\nspacy download en\n\npip install -r requirements.txt\n\ngit clone https://github.com/NVIDIA/apex.git &&\\\n    cd apex &&\\\n    pip install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" . &&\\\n    rm -rf ../apex\n\n"
  },
  {
    "path": "env/requirements.txt",
    "content": "ipdb\njoblib\ncytoolz\nlz4==2.1.9\nlmdb==0.97\nmsgpack-numpy\nmsgpack\ntoolz\ntransformers==4.11.3\ntensorboard\ntqdm\neasydict\npycocotools>=2.0.1\nopencv-python\ntensorboardX==2.0\nav==8.0.2\nujson\neinops\ndecord\ntimm\n"
  },
  {
    "path": "run_scripts/clear_cuda_cache.sh",
    "content": "for i in $(lsof /dev/nvidia* | grep python  | awk '{print $2}' | sort -u); do kill -9 $i; done\n\n\n\n"
  },
  {
    "path": "run_scripts/ft_didemo_ret.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/didemo_ret.json'\n\nhorovodrun -np 8 python src/tasks/run_video_retrieval.py \\\n      --config $CONFIG_PATH \\\n      --output_dir /export/home/workspace/experiments/alpro/finetune/didemo_ret/$(date '+%Y%m%d%H%M%S')\n"
  },
  {
    "path": "run_scripts/ft_msrvtt_qa.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/msrvtt_qa.json'\n\nhorovodrun -np 8 python src/tasks/run_video_qa.py \\\n      --config $CONFIG_PATH \\\n      --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_qa/$(date '+%Y%m%d%H%M%S')\n"
  },
  {
    "path": "run_scripts/ft_msrvtt_ret.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/msrvtt_ret.json'\n\nhorovodrun -np 8 python src/tasks/run_video_retrieval.py \\\n      --config $CONFIG_PATH \\\n      --output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S')"
  },
  {
    "path": "run_scripts/ft_msvd_qa.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/msvd_qa.json'\n\nhorovodrun -np 8 python src/tasks/run_video_qa.py \\\n      --config $CONFIG_PATH \\\n      --output_dir /export/home/workspace/experiments/alpro/finetune/msvd_qa/$(date '+%Y%m%d%H%M%S')\n"
  },
  {
    "path": "run_scripts/inf_didemo_ret.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/didemo_ret.json'\n\nTXT_DB='data/didemo_ret/txt/test.jsonl'\nIMG_DB='data/didemo_ret/videos'\n\nhorovodrun -np 8 python src/tasks/run_video_retrieval.py \\\n      --do_inference 1 \\\n      --inference_split test \\\n      --inference_model_step $STEP \\\n      --inference_txt_db $TXT_DB \\\n      --inference_img_db $IMG_DB \\\n      --inference_batch_size 64 \\\n      --output_dir output/downstreams/didemo_ret/public \\\n      --config $CONFIG_PATH"
  },
  {
    "path": "run_scripts/inf_msrvtt_qa.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/msrvtt_qa.json'\n\nTXT_DB='data/msrvtt_qa/txt/test.jsonl'\nIMG_DB='data/msrvtt_qa/videos'\n\nhorovodrun -np 8 python src/tasks/run_video_qa.py \\\n      --do_inference 1 \\\n      --inference_split test \\\n      --inference_model_step $STEP \\\n      --inference_txt_db $TXT_DB \\\n      --inference_img_db $IMG_DB \\\n      --inference_batch_size 64 \\\n      --output_dir output/downstreams/msrvtt_qa/public \\\n      --config $CONFIG_PATH"
  },
  {
    "path": "run_scripts/inf_msrvtt_ret.sh",
    "content": "cd ..\n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/msrvtt_ret.json'\n\nTXT_DB='data/msrvtt_ret/txt/test.jsonl'\nIMG_DB='data/msrvtt_ret/videos'\n\nhorovodrun -np 8 python src/tasks/run_video_retrieval.py \\\n      --do_inference 1 \\\n      --inference_split test \\\n      --inference_model_step $STEP \\\n      --inference_txt_db $TXT_DB \\\n      --inference_img_db $IMG_DB \\\n      --inference_batch_size 64 \\\n      --output_dir  output/downstreams/msrvtt_ret/public \\\n      --config $CONFIG_PATH"
  },
  {
    "path": "run_scripts/inf_msvd_qa.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nSTEP='best'\n\nCONFIG_PATH='config_release/msvd_qa.json'\n\nTXT_DB='data/msvd_qa/txt/test.jsonl'\nIMG_DB='data/msvd_qa/videos'\n\nhorovodrun -np 8 python src/tasks/run_video_qa.py \\\n      --do_inference 1 \\\n      --inference_split test \\\n      --inference_model_step $STEP \\\n      --inference_txt_db $TXT_DB \\\n      --inference_img_db $IMG_DB \\\n      --inference_batch_size 64 \\\n      --output_dir output/downstreams/msvd_qa/public \\\n      --config $CONFIG_PATH"
  },
  {
    "path": "run_scripts/pt_alpro.sh",
    "content": "cd ..\n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/pretrain_alpro.json'\n\nhorovodrun -np 16 python src/pretrain/run_pretrain_sparse.py \\\n      --config $CONFIG_PATH \\\n      --output_dir /export/home/workspace/experiments/alpro/vl/$(date '+%Y%m%d%H%M%S')"
  },
  {
    "path": "run_scripts/pt_prompter.sh",
    "content": "cd .. \n\nexport PYTHONPATH=\"$PYTHONPATH:$PWD\"\necho $PYTHONPATH\n\nCONFIG_PATH='config_release/pretrain_prompter.json'\n\nhorovodrun -np 8 python src/pretrain/run_pretrain_contrastive_only.py \\\n      --config $CONFIG_PATH \\\n      --output_dir /export/home/workspace/experiments/alpro/prompter/$(date '+%Y%m%d%H%M%S')"
  },
  {
    "path": "src/__init__.py",
    "content": ""
  },
  {
    "path": "src/configs/config.py",
    "content": "\"\"\"\nModified from UNITER code\n\"\"\"\nimport os\nimport sys\nimport json\nimport argparse\n\nfrom easydict import EasyDict as edict\n\n\ndef parse_with_config(parsed_args):\n    \"\"\"This function will set args based on the input config file.\n    (1) it only overwrites unset parameters,\n        i.e., these parameters not set from user command line input\n    (2) it also sets configs in the config file but declared in the parser\n    \"\"\"\n    # convert to EasyDict object, enabling access from attributes even for nested config\n    # e.g., args.train_datasets[0].name\n    args = edict(vars(parsed_args))\n    if args.config is not None:\n        config_args = json.load(open(args.config))\n        override_keys = {arg[2:].split(\"=\")[0] for arg in sys.argv[1:]\n                         if arg.startswith(\"--\")}\n        for k, v in config_args.items():\n            if k not in override_keys:\n                setattr(args, k, v)\n    del args.config\n    return args\n\n\nclass SharedConfigs(object):\n    \"\"\"Shared options for pre-training and downstream tasks.\n    For each downstream task, implement a get_*_args function,\n    see `get_pretraining_args()`\n\n    Usage:\n    >>> shared_configs = SharedConfigs()\n    >>> pretraining_config = shared_configs.get_pretraining_args()\n    \"\"\"\n\n    def __init__(self, desc=\"shared config for pretraining and finetuning\"):\n        parser = argparse.ArgumentParser(description=desc)\n        # debug parameters\n        parser.add_argument(\n            \"--debug\", type=int, choices=[0, 1], default=0,\n            help=\"debug mode, output extra info & break all loops.\"\n                 \"0: disable, 1 enable\")\n        parser.add_argument(\n            \"--data_ratio\", type=float, default=1.0,\n            help=\"portion of train/val exampels to use,\"\n                 \"e.g., overfit a small set of data\")\n\n        # Required parameters\n        parser.add_argument(\n            \"--model_config\", type=str,\n            help=\"path to model structure config json\")\n        parser.add_argument(\n            \"--tokenizer_dir\", type=str, help=\"path to tokenizer dir\")\n        parser.add_argument(\n            \"--output_dir\", type=str,\n            help=\"dir to store model checkpoints & training meta.\")\n\n        # data preprocessing parameters\n        parser.add_argument(\n            \"--max_txt_len\", type=int, default=20, help=\"max text #tokens \")\n        # parser.add_argument(\n        #     \"--max_img_size\", type=int, default=448,\n        #     help=\"max image longer side size, shorter side will be padded with zeros\")\n        parser.add_argument(\n            \"--img_pixel_mean\", type=float, default=None,\n            nargs=3, help=\"image pixel mean\")\n        parser.add_argument(\n            \"--img_pixel_std\", type=float, default=None,\n            nargs=3, help=\"image pixel std\")\n        parser.add_argument(\n            \"--img_input_format\", type=str, default=\"BGR\",\n            choices=[\"BGR\", \"RGB\"], help=\"image input format is BGR for detectron2\")\n        parser.add_argument(\n            \"--max_n_example_per_group\", type=int, default=1,\n            help=\"max #examples (e.g., captions) paired with each image/video in an input group.\"\n                 \"1: each image is paired with a single sent., equivalent to sample by sent.;\"\n                 \"X (X>1): each image can be paired with a maximum of X sent.; X>1 can be used \"\n                 \"to 
reduce image processing time, including basic transform (resize, etc) and CNN encoding\"\n        )\n        # video specific parameters\n        parser.add_argument(\"--fps\", type=int, default=1, help=\"video frame rate to use\")\n        parser.add_argument(\"--num_frm\", type=int, default=3,\n                            help=\"#frames to use per clip -- we first sample a clip from a video, \"\n                                 \"then uniformly sample num_frm from the clip. The length of the clip \"\n                                 \"will be fps * num_frm\")\n        parser.add_argument(\"--frm_sampling_strategy\", type=str, default=\"rand\",\n                            choices=[\"rand\", \"uniform\", \"start\", \"middle\", \"end\", \"headtail\"],\n                            help=\"see src.datasets.dataset_base.extract_frames_from_video_binary for details\")\n\n        # MIL training settings\n        parser.add_argument(\"--train_n_clips\", type=int, default=3,\n                            help=\"#clips to sample from each video for MIL training\")\n        parser.add_argument(\"--score_agg_func\", type=str, default=\"mean\",\n                            choices=[\"mean\", \"max\", \"lse\"],\n                            help=\"score (from multiple clips) aggregation function, lse = LogSumExp\")\n        parser.add_argument(\"--random_sample_clips\", type=int, default=1, choices=[0, 1],\n                            help=\"randomly sample clips for training, otherwise use uniformly sampled clips.\")\n\n        # training parameters\n        parser.add_argument(\n            \"--train_batch_size\", default=128, type=int,\n            help=\"Single-GPU batch size for training for Horovod.\")\n        parser.add_argument(\n            \"--val_batch_size\", default=128, type=int,\n            help=\"Single-GPU batch size for validation for Horovod.\")\n        parser.add_argument(\n            \"--gradient_accumulation_steps\", type=int, default=1,\n            help=\"#update steps to accumulate before performing a backward/update pass. \"\n                 \"Used to simulate larger batch size training. 
The simulated batch size \"\n                 \"is train_batch_size * gradient_accumulation_steps for a single GPU.\")\n        parser.add_argument(\"--learning_rate\", default=5e-5, type=float,\n                            help=\"initial learning rate.\")\n        parser.add_argument(\n            \"--log_interval\", default=500, type=int,\n            help=\"record every few steps on tensorboard.\")\n        parser.add_argument(\n            \"--num_valid\", default=20, type=int,\n            help=\"Run validation X times during training and checkpoint.\")\n        parser.add_argument(\n            \"--min_valid_steps\", default=100, type=int,\n            help=\"minimum #steps between two validation runs\")\n        parser.add_argument(\n            \"--save_steps_ratio\", default=0.01, type=float,\n            help=\"save every 0.01*global steps to resume after preemption, \"\n                 \"not used for checkpointing.\")\n        parser.add_argument(\"--num_train_epochs\", default=10, type=int,\n                            help=\"Total #training epochs.\")\n        parser.add_argument(\"--optim\", default=\"adamw\",\n                            choices=[\"adam\", \"adamax\", \"adamw\"],\n                            help=\"optimizer\")\n        parser.add_argument(\"--betas\", default=[0.9, 0.98], type=float,\n                            nargs=2, help=\"betas for the adam optimizer\")\n        parser.add_argument(\"--decay\", default=\"linear\",\n                            choices=[\"linear\", \"invsqrt\", \"multi_step\"],\n                            help=\"learning rate decay method\")\n        parser.add_argument(\"--dropout\", default=0.1, type=float,\n                            help=\"tune dropout regularization\")\n        parser.add_argument(\"--weight_decay\", default=1e-3, type=float,\n                            help=\"weight decay (L2) regularization\")\n        parser.add_argument(\"--grad_norm\", default=2.0, type=float,\n                            help=\"gradient clipping (-1 for no clipping)\")\n        parser.add_argument(\n            \"--warmup_ratio\", default=0.1, type=float,\n            help=\"fraction of total steps to perform linear learning rate warmup for. (invsqrt decay)\")\n        parser.add_argument(\"--transformer_lr_mul\", default=1.0, type=float,\n                            help=\"lr_mul for transformer\")\n        parser.add_argument(\"--step_decay_epochs\", type=int,\n                            nargs=\"+\", help=\"multi_step decay epochs\")\n        # model arch\n        parser.add_argument(\n            \"--model_type\", type=str, default=\"pretrain\",\n            help=\"type of e2e model to use. Supports only 'pretrain' for now. 
\")\n        parser.add_argument(\n            \"--timesformer_model_cfg\", type=str, default=\"\",\n            help=\"path to timesformer model cfg yaml\")\n\n        # checkpoint\n        parser.add_argument(\"--e2e_weights_path\", type=str,\n                            help=\"path to e2e model weights\")\n        parser.add_argument(\n            \"--clip_init\", default=0, type=int, choices=[0, 1],\n            help=\"1 for using clip ckpt for init.\")\n        parser.add_argument(\"--bert_weights_path\", type=str,\n                            help=\"path to BERT weights, only use for pretraining\")\n\n        # inference only, please include substring `inference'\n        # in the option to avoid been overwrite by loaded options,\n        # see start_inference() in run_vqa_w_hvd.py\n        parser.add_argument(\"--inference_model_step\", default=-1, type=str,\n                            help=\"pretrained model checkpoint step\")\n        parser.add_argument(\n            \"--do_inference\", default=0, type=int, choices=[0, 1],\n            help=\"perform inference run. 0: disable, 1 enable\")\n        parser.add_argument(\n            \"--inference_split\", default=\"val\",\n            help=\"For val, the data should have ground-truth associated it.\"\n                 \"For test*, the data comes with no ground-truth.\")\n        parser.add_argument(\"--inference_txt_db\", type=str,\n                            help=\"path to txt_db file for inference\")\n        parser.add_argument(\"--inference_img_db\", type=str,\n                            help=\"path to img_db file for inference\")\n        parser.add_argument(\"--inference_batch_size\", type=int, default=64,\n                            help=\"single-GPU batch size for inference\")\n        parser.add_argument(\"--inference_n_clips\", type=int, default=1,\n                            help=\"uniformly sample `ensemble_n_clips` clips, \"\n                                 \"each contains `num_frm` frames. When it == 1, \"\n                                 \"use the frm_sampling_strategy to sample num_frm frames.\"\n                                 \"When it > 1, ignore frm_sampling_strategy, \"\n                                 \"uniformly sample N clips, each clips num_frm frames.\")\n\n        # device parameters\n        parser.add_argument(\"--seed\", type=int, default=42,\n                            help=\"random seed for initialization\")\n        parser.add_argument(\n            \"--fp16\", type=int, choices=[0, 1], default=0,\n            help=\"Use 16-bit float precision instead of 32-bit.\"\n                 \"0: disable, 1 enable\")\n        parser.add_argument(\"--n_workers\", type=int, default=4,\n                            help=\"#workers for data loading\")\n        parser.add_argument(\"--pin_mem\", type=int, choices=[0, 1], default=1,\n                            help=\"pin memory. 
0: disable, 1 enable\")\n\n        # can use config files, will only overwrite unset parameters\n        parser.add_argument(\"--config\", help=\"JSON config files\")\n        self.parser = parser\n\n    def parse_args(self):\n        parsed_args = self.parser.parse_args()\n        args = parse_with_config(parsed_args)\n\n        # convert to all [0, 1] options to bool, including these task specific ones\n        zero_one_options = [\n            \"fp16\", \"pin_mem\", \"use_itm\", \"use_mlm\", \"use_itc\", \"debug\", #\"freeze_cnn\",\n            \"do_inference\",\n        ]\n        for option in zero_one_options:\n            if hasattr(args, option):\n                setattr(args, option, bool(getattr(args, option)))\n\n        # basic checks\n        # This is handled at TrainingRestorer\n        # if exists(args.output_dir) and os.listdir(args.output_dir):\n        #     raise ValueError(f\"Output directory ({args.output_dir}) \"\n        #                      f\"already exists and is not empty.\")\n        if args.step_decay_epochs and args.decay != \"multi_step\":\n            Warning(\n                f\"--step_decay_epochs epochs set to {args.step_decay_epochs}\"\n                f\"but will not be effective, as --decay set to be {args.decay}\")\n\n        assert args.gradient_accumulation_steps >= 1, \\\n            f\"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps} \"\n\n        assert 1 >= args.data_ratio > 0, \\\n            f\"--data_ratio should be [1.0, 0), but get {args.data_ratio}\"\n\n        return args\n\n    def get_sparse_pretraining_args(self):\n        # pre-training args\n        self.parser.add_argument(\n            \"--use_itm\", type=int, choices=[0, 1], default=0,\n            help=\"enable itm loss. 0: disable, 1 enable\")\n        self.parser.add_argument(\n            \"--use_mlm\", type=int, choices=[0, 1], default=0,\n            help=\"enable mlm loss. 0: disable, 1 enable\")\n        self.parser.add_argument(\n            \"--use_itc\", type=int, choices=[0, 1], default=0,\n            help=\"enable itc loss. 0: disable, 1 enable\")\n        \n        # sparse pretraining-specific settings\n        self.parser.add_argument(\n            \"--crop_img_size\", type=int, default=256,\n            help=\"crop size during pre-training.\")\n        self.parser.add_argument(\n            \"--resize_size\", type=int, default=288,\n            help=\"resize frames to square, ignoring aspect ratio.\")\n\n        # MPM-specific\n        self.parser.add_argument(\n            \"--use_mpm\", type=int, choices=[0, 1], default=0,\n            help=\"enable mpm loss. 
0: disable, 1 enable\")\n        self.parser.add_argument(\"--teacher_weights_path\", type=str,\n                            help=\"path to teacher model weights, only use for pretraining.\")\n        self.parser.add_argument(\"--entity_file_path\", type=str,\n                            help=\"path to selected NOUN entities.\")\n        self.parser.add_argument(\n            \"--num_entities\", type=int, default=1000,\n            help=\"maximum entities to consider for pseudo labels.\")\n\n        args = self.parse_args()\n        return args\n\n    def get_video_retrieval_args(self):\n        self.parser.add_argument(\"--eval_retrieval_batch_size\", type=int, default=256,\n                                 help=\"batch size for retrieval, since each batch will only have one image, \"\n                                      \"retrieval allows larger batch size\")\n\n        args = self.parse_args()\n        return args\n\n    def get_nlvl_args(self):\n        args = self.parse_args()\n\n        return args\n\n\n    def get_vqa_args(self):\n        self.parser.add_argument(\"--ans2label_path\", type=str,\n                                 help=\"path to {answer: label} file\")\n        self.parser.add_argument(\"--loss_type\", type=str, default=\"bce\",\n                                 help=\"loss type\")\n        self.parser.add_argument(\"--classifier\", type=str, default=\"mlp\",\n                                 choices=[\"mlp\", \"linear\"],\n                                 help=\"classifier type\")\n        self.parser.add_argument(\n            \"--cls_hidden_scale\", type=int, default=2,\n            help=\"scaler of the intermediate linear layer dimension for mlp classifier\")\n        self.parser.add_argument(\"--num_labels\", type=int, default=3129,\n                                 help=\"#labels/output-dim for classifier\")\n        return self.parse_args()\n\n    def get_video_qa_args(self):\n        self.parser.add_argument(\n            \"--task\", type=str,\n            choices=[\"action\", \"transition\", \"frameqa\", \"msrvtt_qa\"],\n            help=\"TGIF-QA tasks and MSRVTT-QA\")\n        self.parser.add_argument(\"--loss_type\", type=str, default=\"ce\",\n                                 help=\"loss type, will be overwritten later\")\n        self.parser.add_argument(\"--classifier\", type=str, default=\"mlp\",\n                                 choices=[\"mlp\", \"linear\"],\n                                 help=\"classifier type\")\n        self.parser.add_argument(\n            \"--cls_hidden_scale\", type=int, default=2,\n            help=\"scaler of the intermediate linear layer dimension for mlp classifier\")\n        # for frameQA msrvtt_qa\n        self.parser.add_argument(\"--ans2label_path\", type=str, default=None,\n                                 help=\"path to {answer: label} file\")\n\n        # manually setup config by task type\n        args = self.parse_args()\n        if args.max_n_example_per_group != 1:\n            Warning(f\"For TGIF-QA, most GIF is only paired with a single example, no need to\"\n                    f\"use max_n_example_per_group={args.max_n_example_per_group}\"\n                    f\"larger than 1. 
Automatically reset to 1.\")\n            args.max_n_example_per_group = 1\n        if os.path.exists(args.ans2label_path):\n            num_answers = len(json.load(open(args.ans2label_path, \"r\")))\n        else:\n            num_answers = 0\n\n        if args.task in [\"msrvtt_qa\", \"msvd_qa\"]:\n            args.num_labels = max(num_answers, 1500)\n            args.loss_type = \"ce\"\n        else:\n            raise NotImplementedError\n        return args\n\n\nshared_configs = SharedConfigs()\n"
  },
  {
    "path": "src/datasets/data_utils.py",
    "content": "import torch\nimport random\nimport torchvision.transforms as transforms\nfrom torchvision.transforms.functional import pad as img_pad\nfrom torchvision.transforms.functional import resize as img_resize\nfrom torch.nn.functional import interpolate as img_tensor_resize\nfrom torch.nn.functional import pad as img_tensor_pad\nfrom torch.nn.modules.utils import _quadruple\nfrom src.utils.basic_utils import flat_list_of_lists\nimport numbers\nimport numpy as np\nfrom PIL import Image\n_pil_interpolation_to_str = {\n    Image.NEAREST: 'PIL.Image.NEAREST',\n    Image.BILINEAR: 'PIL.Image.BILINEAR',\n    Image.BICUBIC: 'PIL.Image.BICUBIC',\n    Image.LANCZOS: 'PIL.Image.LANCZOS',\n    Image.HAMMING: 'PIL.Image.HAMMING',\n    Image.BOX: 'PIL.Image.BOX',\n}\n\n\ndef mask_batch_text_tokens(\n        inputs, tokenizer, mlm_probability=0.15, is_train=True):\n    \"\"\" modified from transformers.data.data_collator\n    Args:\n        inputs: (B, L), 2D torch.Tensor, does not work for 1D. It has already been padded.\n        tokenizer:\n        mlm_probability: float\n        is_train: if True use random masking, else mask tokens at fixed position to remove randomness in evaluation.\n    \"\"\"\n    if tokenizer.mask_token is None:\n        raise ValueError(\n            \"This tokenizer does not have a mask token which is necessary for masked language modeling. \"\n            \"Remove the --mlm flag if you want to use this tokenizer.\"\n        )\n\n    labels = inputs.clone()\n    # We sample a few tokens in each sequence for masked-LM training\n    # (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)\n    probability_matrix = torch.full(labels.shape, mlm_probability)\n    special_tokens_mask = [\n        tokenizer.get_special_tokens_mask(\n            val, already_has_special_tokens=True) for val in labels.tolist()\n    ]\n    probability_matrix.masked_fill_(torch.tensor(\n        special_tokens_mask, dtype=torch.bool), value=0.0)\n    if tokenizer._pad_token is not None:\n        padding_mask = labels.eq(tokenizer.pad_token_id)\n        probability_matrix.masked_fill_(padding_mask, value=0.0)\n    masked_indices = torch.bernoulli(probability_matrix).bool()\n    labels[~masked_indices] = -100  # We only compute loss on masked tokens\n\n    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])\n    indices_replaced = torch.bernoulli(\n        torch.full(labels.shape, 0.8)).bool() & masked_indices\n    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(\n        tokenizer.mask_token)\n\n    # 10% of the time, we replace masked input tokens with random word\n    indices_random = torch.bernoulli(\n        torch.full(labels.shape, 0.5)\n        ).bool() & masked_indices & ~indices_replaced\n    random_words = torch.randint(\n        len(tokenizer), labels.shape,\n        dtype=torch.long)  # len(tokenizer) == #vocab\n    inputs[indices_random] = random_words[indices_random]\n\n    # The rest of the time (10% of the time) we keep the masked input tokens unchanged\n    return inputs, labels\n\n\ndef select_batch_text_pivots(\n        inputs, tokenizer, ent2id, mpm_probability=1.0, is_train=True):\n    \"\"\" Given a input text sequence, generate masks and prototype labels such that:\n    1) not to mask special token ([CLS], [SEP], [MASK], [PAD]);\n    2) always mask all BPE in a word together.\n\n    Args:\n    \"\"\"\n    if tokenizer.mask_token is None:\n        raise ValueError(\n            \"This tokenizer does not have a 
mask token which is necessary for masked language modeling. \"\n            \"Remove the --mlm flag if you want to use this tokenizer.\"\n        )\n\n    labels = inputs.clone()\n    # We sample a few tokens in each sequence as pivots\n    probability_matrix = torch.full(labels.shape, mpm_probability)\n    # ignore [CLS] [SEP] [MASK] tokens\n    special_tokens_mask = [\n        tokenizer.get_special_tokens_mask(\n            val, already_has_special_tokens=True) for val in labels.tolist()\n    ]\n    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)\n    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)\n    # ignore [PAD] tokens\n    if tokenizer._pad_token is not None:\n        padding_mask = labels.eq(tokenizer.pad_token_id)\n        probability_matrix.masked_fill_(padding_mask, value=0.0)\n    \n    # create masking indices\n    pivot_indices = torch.bernoulli(probability_matrix).bool()\n    labels[special_tokens_mask] = -100  # We only compute loss on masked tokens\n    labels[~pivot_indices] = -100  # We only compute loss on masked tokens\n\n    # selected pivot positions: (1) non-special token; (2) selected based on mpm probability. \n    text_pivots_pos = (labels > 0).nonzero()\n    \n    for tpp in text_pivots_pos:\n\n        orig_tpp = tpp.clone()\n        \n        bth = tpp[0]\n        orig_text_pos = orig_tpp[1]\n\n        text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]\n        next_text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]+1]])[0] if tpp[1]+1 < inputs.shape[1] else None\n\n        # TODO: consider supporting encodings beyond sentencepiece. \n        if text_token.startswith('##'):\n            # if it is a byte pair, backtrace until we find the prefix\n            orig_text_token = ''\n\n            while True:\n                if not text_token.startswith('##'):\n                    orig_text_token = text_token + orig_text_token\n\n                    break\n                else:\n                    orig_text_token = text_token[2:] + orig_text_token\n\n                    tpp[1] -= 1\n\n                text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]\n            \n            try:\n                # assign prototype labels to all the sentencepiece bytes in the pivot word\n                labels[bth][tpp[1]: orig_text_pos + 1] = ent2id[orig_text_token]\n                pivot_indices[bth][tpp[1]: orig_text_pos + 1] = True\n            except KeyError:\n                # we do not have this word for prototype\n                labels[bth][orig_text_pos] = -100\n\n        elif next_text_token is not None and next_text_token.startswith('##'):\n            # if it is a prefix, forward-trace until we find the end of the byte pair\n            full_text_token = text_token \n\n            while True:\n                tpp[1] += 1\n                text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]\n\n                if not text_token.startswith('##'):\n                    # find the next prefix/word\n                    break\n                else:\n                    # find continuing bytes\n                    full_text_token = full_text_token + text_token[2:]\n\n            try:\n                # assign prototype labels to all the sentencepiece bytes in the pivot word\n                labels[bth][orig_text_pos: tpp[1]] = ent2id[full_text_token]\n                pivot_indices[bth][orig_text_pos: tpp[1]] = True\n            except KeyError:\n            
    # we do not have this word for prototype\n                labels[bth][orig_text_pos] = -100\n\n        else:\n            # the word is treated in whole by BERT tokenizer\n            try:\n                labels[bth][tpp[1]] = ent2id[text_token]\n            except KeyError:\n                # we do not have this word for prototype\n                labels[bth][tpp[1]] = -100\n\n    # restore mask if the word is not in the entity list\n    pivot_indices[labels==-100] = False\n\n    return pivot_indices, labels\n\n\ndef image_to_tensor(image: np.ndarray, keepdim: bool = True) -> torch.Tensor:\n    \"\"\"Converts a numpy image to a PyTorch 4d tensor image.\n    Args:\n        image (numpy.ndarray): image of the form :math:`(H, W, C)`, :math:`(H, W)` or\n            :math:`(B, H, W, C)`.\n        keepdim (bool): If ``False`` unsqueeze the input image to match the shape\n            :math:`(B, H, W, C)`. Default: ``True``\n    Returns:\n        torch.Tensor: tensor of the form :math:`(B, C, H, W)` if keepdim is ``False``,\n            :math:`(C, H, W)` otherwise.\n    \"\"\"\n    if not isinstance(image, (np.ndarray,)):\n        raise TypeError(\"Input type must be a numpy.ndarray. Got {}\".format(\n            type(image)))\n\n    if len(image.shape) > 4 or len(image.shape) < 2:\n        raise ValueError(\n            \"Input size must be a two, three or four dimensional array\")\n\n    input_shape = image.shape\n    tensor: torch.Tensor = torch.from_numpy(image)\n\n    if len(input_shape) == 2:\n        # (H, W) -> (1, H, W)\n        tensor = tensor.unsqueeze(0)\n    elif len(input_shape) == 3:\n        # (H, W, C) -> (C, H, W)\n        tensor = tensor.permute(2, 0, 1)\n    elif len(input_shape) == 4:\n        # (B, H, W, C) -> (B, C, H, W)\n        tensor = tensor.permute(0, 3, 1, 2)\n        keepdim = True  # no need to unsqueeze\n    else:\n        raise ValueError(\n            \"Cannot process image with shape {}\".format(input_shape))\n\n    return tensor.unsqueeze(0) if not keepdim else tensor\n\n\ndef get_padding(image, max_w, max_h, pad_all=False):\n    # keep the images to upper-left corner\n    if isinstance(image, torch.Tensor):\n        h, w = image.shape[-2:]\n    else:\n        w, h = image.size\n    h_padding, v_padding = max_w - w, max_h - h\n    if pad_all:\n        h_padding /= 2\n        v_padding /= 2\n        l_pad = h_padding if h_padding % 1 == 0 else h_padding+0.5\n        t_pad = v_padding if v_padding % 1 == 0 else v_padding+0.5\n        r_pad = h_padding if h_padding % 1 == 0 else h_padding-0.5\n        b_pad = v_padding if v_padding % 1 == 0 else v_padding-0.5\n    else:\n        l_pad, t_pad = 0, 0\n        r_pad, b_pad = h_padding, v_padding\n    if isinstance(image, torch.Tensor):\n        padding = (int(l_pad), int(r_pad), int(t_pad), int(b_pad))\n    else:\n        padding = (int(l_pad), int(t_pad), int(r_pad), int(b_pad))\n    return padding\n\n\nclass ImagePad(object):\n    def __init__(self, max_w, max_h, fill=0, padding_mode='constant'):\n        assert isinstance(fill, (numbers.Number, str, tuple))\n        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']\n        self.max_w = max_w\n        self.max_h = max_h\n        self.fill = fill\n        self.padding_mode = padding_mode\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be padded.\n\n        Returns:\n            PIL Image: Padded image.\n        \"\"\"\n        if isinstance(img, torch.Tensor):\n            paddings = 
_quadruple(get_padding(img, self.max_w, self.max_h))\n            return img_tensor_pad(\n                img, paddings,\n                self.padding_mode, self.fill)\n        return img_pad(\n            img, get_padding(img, self.max_w, self.max_h),\n            self.fill, self.padding_mode)\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(max_w={0}, max_h={1}, fill={2}, padding_mode={3})'.\\\n            format(self.max_w, self.max_h, self.fill, self.padding_mode)\n\n\ndef get_resize_size(image, max_size):\n    \"\"\"\n    Args:\n        image: PIL Image or torch.tensor\n        max_size: int, target length of the longer side after resizing\n\n    Returns:\n        (new_height, new_width): tuple(int), with the longer side scaled to max_size\n\n    Note the height/width order difference\n    >>> pil_img = Image.open(\"raw_img_tensor.jpg\")\n    >>> pil_img.size\n    (640, 480)  # (width, height)\n    >>> np_img = np.array(pil_img)\n    >>> np_img.shape\n    (480, 640, 3)  # (height, width, 3)\n    \"\"\"\n    # note the order of height and width for different inputs\n    if isinstance(image, torch.Tensor):\n        height, width = image.shape[-2:]\n    else:\n        width, height = image.size\n\n    if height >= width:\n        ratio = width*1./height\n        new_height = max_size\n        new_width = new_height * ratio\n    else:\n        ratio = height*1./width\n        new_width = max_size\n        new_height = new_width * ratio\n    size = (int(new_height), int(new_width))\n    return size\n\nclass VideoRandomSquareCrop(object):\n    def __init__(self, crop_size, p=0.5):\n        assert isinstance(crop_size, int)\n        self.crop_size = crop_size\n        self.p = p\n\n    def __call__(self, video):\n        \"\"\"\n        Args:\n            video (torch.tensor): video to be cropped.\n\n        Returns:\n            torch.tensor: cropped video.\n        \"\"\"\n        if isinstance(video, torch.Tensor):\n            if len(video.shape) == 4:\n                b, t, h, w = video.shape\n            else:\n                raise RuntimeError('Expecting 4-dimensional tensor of shape (b,t,h,w), got {}'.format(video.shape))\n\n            # if random.uniform(0, 1) < self.p:\n            #     video = torch.flip(video, (3,))\n\n            x = random.randint(0, h - self.crop_size)\n            y = random.randint(0, w - self.crop_size)\n\n            return video[:, :, x: x + self.crop_size, y: y + self.crop_size]\n\n        else:\n            raise NotImplementedError('Supports only torch.Tensor as input, got {}'.format(type(video)))\n\n\nclass VideoResizeSquare(object):\n    def __init__(self, out_size, interpolation='nearest'):\n        assert isinstance(out_size, int)\n        self.out_size = out_size\n        self.interpolation = interpolation\n\n    def __call__(self, video):\n        \"\"\"\n        Args:\n            video (torch.tensor): video to be scaled.\n\n        Returns:\n            torch.tensor: Rescaled video.\n        \"\"\"\n        if isinstance(video, torch.Tensor):\n            if len(video.shape) == 4:\n                t, c, h, w = video.shape\n                assert c == 3, 'Expecting 3-channel color video, got video of shape {}'.format(video.shape)\n            else:\n                raise RuntimeError('Expecting 4-dimensional tensor of shape (t,c,h,w), got {}'.format(video.shape))\n\n            short_side = h if h < w else w\n            # scaling_factor = self.out_size / short_side\n\n            # new_h = int(h * scaling_factor)\n            # new_w = int(w * scaling_factor)\n\n            resized_video = img_tensor_resize(video, size=((self.out_size, 
self.out_size)), mode=self.interpolation)\n            \n            return resized_video\n\n\n        else:\n            # other dataset classes may pass arrays with a different dimension order.\n            raise NotImplementedError('Supports only torch.Tensor as input, got {}'.format(type(video)))\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(\n            self.out_size, self.interpolation)\n\n\nclass ImageResize(object):\n    \"\"\"Resize the input image (torch.tensor) to the given size.\n\n    Args:\n        max_size (int): Desired output size of the longer edge. The longer\n            edge of the image will be matched to this number, keeping the\n            aspect ratio; i.e., if height > width, the image will be\n            rescaled to (max_size, max_size * width / height).\n        interpolation (int or str, optional): Desired interpolation. Default is\n            ``PIL.Image.BILINEAR``; pass an interpolation mode string (e.g. 'bilinear')\n            when the input is a torch.Tensor.\n    \"\"\"\n\n    def __init__(self, max_size, interpolation=Image.BILINEAR):\n        assert isinstance(max_size, int)\n        self.max_size = max_size\n        self.interpolation = interpolation\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (torch.tensor): Image to be scaled.\n\n        Returns:\n            torch.tensor: Rescaled image.\n        \"\"\"\n        if isinstance(img, torch.Tensor):\n            assert isinstance(self.interpolation, str)\n            return img_tensor_resize(\n                img, size=get_resize_size(img, self.max_size),\n                mode=self.interpolation, align_corners=False)\n        return img_resize(\n            img, get_resize_size(img, self.max_size), self.interpolation)\n\n    def __repr__(self):\n        interpolate_str = (self.interpolation if isinstance(self.interpolation, str)\n                           else _pil_interpolation_to_str[self.interpolation])\n        return self.__class__.__name__ + '(max_size={0}, interpolation={1})'.format(\n            self.max_size, interpolate_str)\n\n\ndef get_imagenet_transform(min_size=600, max_size=1000):\n    \"\"\"parameters from https://github.com/pytorch/examples/blob/master/imagenet/main.py\n    Resizes the longer side to max_size, then pads the result to a max_size x max_size square.\n    \"\"\"\n    if min_size != 600:\n        import warnings\n        warnings.warn(f'Warning: min_size is not used in image transform, '\n                      f'setting min_size will have no effect.')\n    return transforms.Compose([\n        ImageResize(max_size, Image.BILINEAR),  # longer side will be resized to 1000\n        ImagePad(max_size, max_size),  # pad to 1000 * 1000\n    ])\n\n\nclass ImageNorm(object):\n    \"\"\"Apply Normalization to Image Pixels on GPU\n    \"\"\"\n    def __init__(self, mean, std):\n        self.mean = torch.tensor(mean).cuda().view(1, 1, 3, 1, 1)\n        self.std = torch.tensor(std).cuda().view(1, 1, 3, 1, 1)\n        # assert max(std) <= 1 and min(std) >= 0\\\n        #     or max(mean) <= 1 and min(mean) >= 0,\\\n        #         \"Please provide mean or std within range [0, 1]\"\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img: float image tensors, (B, N, 3, H, W)\n\n        Returns:\n            img: normalized float image tensors\n        \"\"\"\n        if torch.max(img) > 1 and self.mean.max() <= 1:\n            img.div_(255.)\n        return img.sub_(self.mean).div_(self.std)\n\n\ndef chunk_list(examples, chunk_size=2, pad_to_divisible=True):\n    \"\"\"\n    Args:\n        examples: iterable, examples grouped by image/video\n        chunk_size: int, number 
of examples in each chunk.\n        pad_to_divisible: bool, pad the examples to be divisible by chunk_size.\n    >>> test_examples = [3, 4, 5, 6, 7]\n    >>> chunk_list(test_examples, chunk_size=2, pad_to_divisible=True)\n    [[3, 4], [5, 6], [7, 7]]  # the last element has some randomness\n    >>> chunk_list(test_examples, chunk_size=2, pad_to_divisible=False)\n    [[3, 4], [5, 6], [7]]\n    \"\"\"\n    n_examples = len(examples)\n    remainder = n_examples % chunk_size\n    if pad_to_divisible and remainder > 0:\n        n_pad = chunk_size - remainder\n        pad = random.choices(examples, k=n_pad)  # with replacement\n        examples = examples + pad\n        n_examples = len(examples)\n        remainder = 0\n    chunked_examples = []\n    n_chunks = int(n_examples / chunk_size)\n    n_chunks = n_chunks + 1 if remainder > 0 else n_chunks\n    for i in range(n_chunks):\n        chunked_examples.append(examples[i*chunk_size: (i+1)*chunk_size])\n    return chunked_examples\n\n\ndef mk_input_group(key_grouped_examples, max_n_example_per_group=1, is_train=True,\n                   example_unique_key=None):\n    \"\"\" Re-organize examples into groups. Each input group will have a single image paired\n    with X (X=max_n_example_per_group) examples. Images with total #examples > X will be\n    split into multiple groups. In the case a group has < X examples, we will copy\n    the examples to make the group have X examples.\n    Args:\n        key_grouped_examples: dict, each key is image/video id,\n            each value is a list(example) associated with this image/video\n        max_n_example_per_group: int, pair max #examples with each image/video.\n           Note that each image can have multiple groups.\n        is_train: bool, if True, copy the examples to make sure each input\n            group has max_n_example_per_group examples.\n        example_unique_key: str, used to make sure no inputs are discarded by matching\n            the input and output ids specified by `example_unique_key`\n    \"\"\"\n    input_groups = []  # each element is (id, list(example))\n    for k, examples in key_grouped_examples.items():\n        chunked_examples = chunk_list(examples,\n                                      chunk_size=max_n_example_per_group,\n                                      pad_to_divisible=is_train)\n        for c in chunked_examples:\n            # if len(c) == 0:\n            #     continue\n            input_groups.append((k, c))\n\n    if example_unique_key is not None:\n        print(f\"Using example_unique_key {example_unique_key} to check whether input and output ids match.\")\n        # sanity check: make sure we did not discard any input example by accident.\n        input_question_ids = flat_list_of_lists(\n            [[sub_e[example_unique_key] for sub_e in e] for e in key_grouped_examples.values()])\n        output_question_ids = flat_list_of_lists(\n            [[sub_e[example_unique_key] for sub_e in e[1]] for e in input_groups])\n        assert set(input_question_ids) == set(output_question_ids), \"You are missing examples: input and output ids do not match.\"\n    return input_groups\n\n\n# def repeat_tensor_rows(raw_tensor, row_repeats):\n#     \"\"\" repeat raw_tensor[i] row_repeats[i] times.\n#     Args:\n#         raw_tensor: (B, *)\n#         row_repeats: list(int), len(row_repeats) == len(raw_tensor)\n#     \"\"\"\n#     assert len(raw_tensor) == len(row_repeats), \"Has to be the same length\"\n#     if sum(row_repeats) == len(row_repeats):\n#         return raw_tensor\n#     else:\n#         indices = 
torch.LongTensor(\n#             flat_list_of_lists([[i] * r for i, r in enumerate(row_repeats)])\n#         ).to(raw_tensor.device)\n#         return raw_tensor.index_select(0, indices)\n\n\n\n"
  },
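  {
    "path": "examples/demo_input_grouping.py",
    "content": "\"\"\"Illustrative usage sketch added by the editor -- NOT part of the original codebase;\nthe file path and all toy values below are hypothetical.\n\nShows how chunk_list and mk_input_group (src.datasets.data_utils) re-group caption\nexamples so that every input group pairs one image/video id with at most\nmax_n_example_per_group text examples.\n\"\"\"\nimport random\n\nfrom src.datasets.data_utils import chunk_list, mk_input_group\n\nrandom.seed(0)  # chunk_list pads by random sampling; fix the seed for a stable demo\n\n# one image/video id -> the caption examples associated with it\nkey_grouped_examples = {\n    'video_0': [{'qid': 0, 'text': 'a dog runs'},\n                {'qid': 1, 'text': 'a dog plays'},\n                {'qid': 2, 'text': 'a dog barks'}],\n    'video_1': [{'qid': 3, 'text': 'a cat sleeps'}],\n}\n\n# pad_to_divisible=True duplicates a randomly chosen example so every chunk has size 2\nprint(chunk_list(key_grouped_examples['video_0'], chunk_size=2, pad_to_divisible=True))\n\n# each element of `groups` is (video_id, list of <= 2 examples);\n# example_unique_key='qid' enables the sanity check that no example was dropped\ngroups = mk_input_group(key_grouped_examples, max_n_example_per_group=2,\n                        is_train=True, example_unique_key='qid')\nfor vid, exs in groups:\n    print(vid, [e['qid'] for e in exs])\n"
  },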
  {
    "path": "src/datasets/dataloader.py",
    "content": "\"\"\"\nmodified from UNITER codebase\n\nA meta data loader for sampling from different datasets / training tasks\nA prefetch loader to speedup data loading\n\"\"\"\nimport random\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom src.utils.distributed import any_broadcast\n\n\nclass MetaLoader(object):\n    \"\"\" wraps multiple data loader \"\"\"\n    def __init__(self, loaders, accum_steps=1, distributed=False):\n        assert isinstance(loaders, dict)\n        self.name2loader = {}\n        self.name2iter = {}\n        self.sampling_pools = []\n        n_batches_in_epoch = 0\n        for n, l in loaders.items():\n            if isinstance(l, tuple):\n                l, r = l\n            elif isinstance(l, DataLoader):\n                r = 1\n            else:\n                raise ValueError()\n            n_batches_in_epoch += len(l.dataset) * r / l.batch_size\n            self.name2loader[n] = l\n            self.name2iter[n] = iter(l)\n            self.sampling_pools.extend([n]*r)\n        self.n_batches_in_epoch = n_batches_in_epoch\n        self.accum_steps = accum_steps\n        self.distributed = distributed\n        self.step = 0\n\n    def __iter__(self):\n        \"\"\" this iterator will run indefinitely \"\"\"\n        task = self.sampling_pools[0]\n        while True:\n            if self.step % self.accum_steps == 0:\n                task = random.choice(self.sampling_pools)\n                if self.distributed:\n                    # make sure all process is training same task\n                    task = any_broadcast(task, 0)\n            self.step += 1\n            iter_ = self.name2iter[task]\n            try:\n                batch = next(iter_)\n            except StopIteration:\n                iter_ = iter(self.name2loader[task])\n                batch = next(iter_)\n                self.name2iter[task] = iter_\n\n            yield task, batch\n\n\ndef move_to_cuda(batch):\n    if isinstance(batch, torch.Tensor):\n        return batch.cuda(non_blocking=True)\n    elif isinstance(batch, list):\n        new_batch = [move_to_cuda(t) for t in batch]\n    elif isinstance(batch, tuple):\n        new_batch = tuple(move_to_cuda(t) for t in batch)\n    elif isinstance(batch, dict):\n        new_batch = {n: move_to_cuda(t) for n, t in batch.items()}\n    else:\n        return batch\n    return new_batch\n\n\ndef record_cuda_stream(batch):\n    if isinstance(batch, torch.Tensor):\n        batch.record_stream(torch.cuda.current_stream())\n    elif isinstance(batch, list) or isinstance(batch, tuple):\n        for t in batch:\n            record_cuda_stream(t)\n    elif isinstance(batch, dict):\n        for t in batch.values():\n            record_cuda_stream(t)\n    else:\n        pass\n\n\nclass PrefetchLoader(object):\n    \"\"\"\n    overlap compute and cuda data transfer\n    (copied and then modified from nvidia apex)\n    \"\"\"\n    def __init__(self, loader, img_normalize=None):\n        self.loader = loader\n        self.stream = torch.cuda.Stream()\n        self.img_normalize = img_normalize\n\n    def __iter__(self):\n        loader_it = iter(self.loader)\n        self.preload(loader_it)\n        batch = self.next(loader_it)\n        while batch is not None:\n            is_tuple = isinstance(batch, tuple)\n            if is_tuple:\n                task, batch = batch\n            batch[\"visual_inputs\"] = batch[\"visual_inputs\"].float()\n            if self.img_normalize is not None:\n                batch[\"visual_inputs\"] = 
self.img_normalize(\n                    batch[\"visual_inputs\"])\n                if \"crop_visual_inputs\" in batch:\n                    batch[\"crop_visual_inputs\"] = batch[\"crop_visual_inputs\"].float()\n                    batch[\"crop_visual_inputs\"] = self.img_normalize(\n                        batch[\"crop_visual_inputs\"])\n                if \"context_visual_inputs\" in batch:\n                    batch[\"context_visual_inputs\"] = batch[\"context_visual_inputs\"].float()\n                    batch[\"context_visual_inputs\"] = self.img_normalize(\n                        batch[\"context_visual_inputs\"])\n            if is_tuple:\n                yield task, batch\n            else:\n                yield batch\n            batch = self.next(loader_it)\n\n    def __len__(self):\n        return len(self.loader)\n\n    def preload(self, it):\n        try:\n            self.batch = next(it)\n        except StopIteration:\n            self.batch = None\n            return\n        # if record_stream() doesn't work, another option is to make sure\n        # device inputs are created on the main stream.\n        # self.next_input_gpu = torch.empty_like(self.next_input,\n        #                                        device='cuda')\n        # self.next_target_gpu = torch.empty_like(self.next_target,\n        #                                         device='cuda')\n        # Need to make sure the memory allocated for next_* is not still in use\n        # by the main stream at the time we start copying to next_*:\n        # self.stream.wait_stream(torch.cuda.current_stream())\n        with torch.cuda.stream(self.stream):\n            self.batch = move_to_cuda(self.batch)\n            # more code for the alternative if record_stream() doesn't work:\n            # copy_ will record the use of the pinned source tensor in this\n            # side stream.\n            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)\n            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)\n            # self.next_input = self.next_input_gpu\n            # self.next_target = self.next_target_gpu\n\n    def next(self, it):\n        torch.cuda.current_stream().wait_stream(self.stream)\n        batch = self.batch\n        if batch is not None:\n            record_cuda_stream(batch)\n        self.preload(it)\n        return batch\n\n    def __getattr__(self, name):\n        method = self.loader.__getattribute__(name)\n        return method\n\n\nclass InfiniteIterator(object):\n    \"\"\"iterate an iterable oobject infinitely\"\"\"\n    def __init__(self, iterable):\n        self.iterable = iterable\n        self.iterator = iter(iterable)\n\n    def __iter__(self):\n        while True:\n            try:\n                batch = next(self.iterator)\n            except StopIteration:\n                self.iterator = iter(self.iterable)\n                batch = next(self.iterator)\n            yield batch\n"
  },
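  {
    "path": "examples/demo_meta_prefetch_loader.py",
    "content": "\"\"\"Editor-added wiring sketch -- NOT part of the original codebase; the file path,\nthe toy dataset and the normalization constants are hypothetical.\n\nShows how MetaLoader and PrefetchLoader (src.datasets.dataloader) are composed:\nMetaLoader samples a task name in proportion to its ratio and yields (task, batch)\nindefinitely; PrefetchLoader copies each batch to the GPU on a side CUDA stream and\noptionally normalizes batch['visual_inputs'].\n\"\"\"\nimport torch\nfrom torch.utils.data import DataLoader, Dataset\n\nfrom src.datasets.data_utils import ImageNorm\nfrom src.datasets.dataloader import MetaLoader, PrefetchLoader\n\n\nclass ToyVideoTextDataset(Dataset):\n    \"\"\"Minimal stand-in that yields dict samples with a 'visual_inputs' key.\"\"\"\n\n    def __len__(self):\n        return 8\n\n    def __getitem__(self, idx):\n        return {'visual_inputs': torch.rand(2, 3, 224, 224),  # (T, C, H, W)\n                'text_input_ids': torch.randint(0, 100, (20,))}\n\n\nif __name__ == '__main__' and torch.cuda.is_available():  # both wrappers need CUDA\n    video_loader = DataLoader(ToyVideoTextDataset(), batch_size=2)\n    image_loader = DataLoader(ToyVideoTextDataset(), batch_size=2)\n\n    # sample the 'video' task twice as often as the 'image' task\n    meta_loader = MetaLoader({'video': (video_loader, 2), 'image': (image_loader, 1)},\n                             accum_steps=1, distributed=False)\n\n    img_norm = ImageNorm(mean=[0.48, 0.46, 0.41], std=[0.27, 0.26, 0.27])\n    prefetch = PrefetchLoader(meta_loader, img_normalize=img_norm)\n\n    for step, (task, batch) in enumerate(prefetch):\n        print(step, task, batch['visual_inputs'].shape, batch['visual_inputs'].device)\n        if step == 4:  # MetaLoader iterates forever, so stop explicitly\n            break\n"
  },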
  {
    "path": "src/datasets/dataset_base.py",
    "content": "from torch.utils.data import Dataset\nfrom PIL import Image\nimport io\nimport av\nimport torch\nimport numpy as np\nimport lmdb\nimport random\nimport decord\nfrom decord import VideoReader\nfrom src.datasets.data_utils import (\n    ImageResize, ImagePad, image_to_tensor)\nfrom src.utils.load_save import LOGGER\n\ndecord.bridge.set_bridge(\"torch\")\n\n\nclass AlproBaseDataset(Dataset):\n    \"\"\"\n    datalist: list(dicts)  # lightly pre-processed\n        {\n        \"type\": \"image\",\n        \"filepath\": \"/abs/path/to/COCO_val2014_000000401092.jpg\",\n        \"text\": \"A plate of food and a beverage are on a table.\",\n                # should be tokenized and digitized first?\n        ...\n        }\n    tokenizer:\n    max_img_size: int,\n    max_txt_len: int, max text sequence length, including special tokens.\n    fps: float, frame per second\n    num_frm: #frames to use as input.\n    \"\"\"\n\n    def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type='lmdb', fps=3, num_frm=3,\n                 frm_sampling_strategy=\"rand\", max_img_size=-1, max_txt_len=20):\n        self.fps = fps\n        self.num_frm = num_frm\n        self.frm_sampling_strategy = frm_sampling_strategy\n        self.datalist = datalist\n        self.tokenizer = tokenizer\n        self.max_txt_len = max_txt_len\n        self.max_img_size = max_img_size\n        self.img_resize = ImageResize(\n            max_img_size,\n            \"bilinear\")  # longer side will be resized to 1000\n        self.img_pad = ImagePad(\n            max_img_size, max_img_size)  # pad to 1000 * 1000\n\n        self.img_db_type = img_db_type\n\n        assert img_db_type in ['lmdb', 'rawvideo'], \"Invalid type for img_db_type, expected {'lmdb', 'rawvideo'}, found {}.\".format(img_db_type)\n\n        if self.img_db_type == 'lmdb':\n            self.env = lmdb.open(\n                img_lmdb_dir, readonly=True,\n                create=False)  # readahead=not _check_distributed()\n            self.txn = self.env.begin(buffers=True)\n        else:\n            self.img_db_dir = img_lmdb_dir\n\n    def __len__(self):\n        return len(self.datalist)\n\n    def __getitem__(self, index):\n        raise NotImplementedError\n\n    def _load_img(self, img_id):\n        \"\"\"Load and apply transformation to image\n\n        Returns:\n            torch.float, in [0, 255], (n_frm=1, c, h, w)\n        \"\"\"\n        raw_img = load_decompress_img_from_lmdb_value(\n            self.txn.get(str(img_id).encode(\"utf-8\"))\n        )\n        image_np = np.array(raw_img, dtype=np.uint8)  # (h, w, c)\n        raw_img_tensor = image_to_tensor(\n            image_np, keepdim=False).float()  # (c, h, w) [0, 255]\n        resized_img = self.img_resize(raw_img_tensor)\n        transformed_img = self.img_pad(\n            resized_img)  # (n_frm=1, c, h, w)\n        return transformed_img\n\n    @classmethod\n    def _is_extreme_aspect_ratio(cls, tensor, max_ratio=5.):\n        \"\"\" find extreme aspect ratio, where longer side / shorter side > max_ratio\n        Args:\n            tensor: (*, H, W)\n            max_ratio: float, max ratio (>1).\n        \"\"\"\n        h, w = tensor.shape[-2:]\n        return h / float(w) > max_ratio or h / float(w) < 1 / max_ratio\n\n    def _load_video(self, video_id, num_clips=None, clip_idx=None,\n                    safeguard_duration=False, video_max_pts=None):\n        \"\"\"Load and sample frames from video.\n        Apply transformation to the sampled frames.\n\n        
Sample a clip:\n            - random: set num_clips and clip_idx to be None\n            - uniform: set num_clips=N, clip_idx=idx. e.g., num_clips=3\n                and clip_idx=1 will first segment the video into 3 clips,\n                then sample the 2nd clip.\n\n        Returns:\n            torch.float, in [0, 255], (n_frm=T, c, h, w)\n        \"\"\"\n        assert (num_clips is None) == (clip_idx is None), \"Both None, or both not None\"\n        # (T, C, H, W) [0, 255]\n        io_stream = io.BytesIO(self.txn.get(str(video_id).encode(\"utf-8\")))\n        raw_sampled_frms, video_max_pts = extract_frames_from_video_binary(\n            io_stream,\n            target_fps=self.fps,\n            num_frames=self.num_frm,\n            multi_thread_decode=False,\n            sampling_strategy=self.frm_sampling_strategy,\n            num_clips=num_clips,\n            clip_idx=clip_idx,\n            safeguard_duration=safeguard_duration,\n            video_max_pts=video_max_pts\n        )\n\n        if raw_sampled_frms is None:\n            return None, None\n        elif self._is_extreme_aspect_ratio(raw_sampled_frms, max_ratio=5.):\n            print(\n                f\"Found extreme aspect ratio for video id {video_id}. Skip it\")\n            return None, None\n\n        raw_sampled_frms = raw_sampled_frms.float()\n        resized_frms = self.img_resize(raw_sampled_frms)\n        padded_frms = self.img_pad(resized_frms)\n        return padded_frms, video_max_pts\n\n\n    def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):\n        try:\n            if not height or not width:\n                vr = VideoReader(video_path)\n            else:\n                vr = VideoReader(video_path, width=width, height=height)\n\n            vlen = len(vr)\n\n            if start_time or end_time:\n                assert fps > 0, 'must provide video fps if specifying start and end time.'\n\n                start_idx = min(int(start_time * fps), vlen)\n                end_idx = min(int(end_time * fps), vlen)\n            else:\n                start_idx, end_idx = 0, vlen\n\n            if self.frm_sampling_strategy == 'uniform':\n                frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)\n            elif self.frm_sampling_strategy == 'nlvl_uniform':\n                frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm).astype(int)\n            elif self.frm_sampling_strategy == 'nlvl_rand':\n                frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm).astype(int)\n\n                # generate some random perturbations\n                strides = [frame_indices[i] - frame_indices[i-1] for i in range(1, len(frame_indices))] + [vlen - frame_indices[-1]]\n                pertube = np.array([np.random.randint(0, stride) for stride in strides])\n\n                frame_indices = frame_indices + pertube\n\n            elif self.frm_sampling_strategy == 'rand':\n                frame_indices = sorted(random.sample(range(vlen), self.num_frm))\n            elif self.frm_sampling_strategy == 'headtail':\n                frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))\n                frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))\n                frame_indices = frame_indices_head + frame_indices_tail\n            else:\n                raise NotImplementedError('Invalid sampling strategy {} 
'.format(self.frm_sampling_strategy))\n\n            raw_sample_frms = vr.get_batch(frame_indices)\n        except Exception as e:\n            return None\n\n        raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)\n\n        return raw_sample_frms\n\ndef img_collate(imgs):\n    \"\"\"\n    Args:\n        imgs:\n\n    Returns:\n        torch.tensor, (B, 3, H, W)\n    \"\"\"\n    w = imgs[0].width\n    h = imgs[0].height\n    tensor = torch.zeros(\n        (len(imgs), 3, h, w), dtype=torch.uint8).contiguous()\n    for i, img in enumerate(imgs):\n        nump_array = np.array(img, dtype=np.uint8)\n        if (nump_array.ndim < 3):\n            nump_array = np.expand_dims(nump_array, axis=-1)\n        # (H, W, 3) --> (3, H, W)\n        nump_array = np.rollaxis(nump_array, 2)\n        tensor[i] += torch.from_numpy(nump_array)\n    return tensor\n"
  },
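  {
    "path": "examples/demo_img_collate.py",
    "content": "\"\"\"Editor-added sketch -- NOT part of the original codebase; the file path and toy\nimages are hypothetical.  Importing dataset_base also pulls in the repo's av / lmdb /\ndecord dependencies.\n\nimg_collate (src.datasets.dataset_base) stacks a list of equally sized PIL images into\na (B, 3, H, W) uint8 tensor, expanding grayscale inputs to a channel dimension before\nthe (H, W, C) -> (C, H, W) rollaxis.\n\"\"\"\nimport numpy as np\nfrom PIL import Image\n\nfrom src.datasets.dataset_base import img_collate\n\n# two toy 64x48 RGB images; img_collate reads H/W from the first image only,\n# so all images are assumed to share the same size\nimgs = [Image.fromarray(np.full((48, 64, 3), v, dtype=np.uint8)) for v in (0, 255)]\n\nbatch = img_collate(imgs)\nprint(batch.shape, batch.dtype)  # torch.Size([2, 3, 48, 64]) torch.uint8\n"
  },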
  {
    "path": "src/datasets/dataset_pretrain_sparse.py",
    "content": "import os\nimport json\nimport random\n\nimport torch\nimport spacy\nfrom torch.utils.data.dataloader import default_collate\nfrom src.utils.logger import LOGGER\nfrom src.utils.basic_utils import flat_list_of_lists, save_frames_grid\nfrom src.datasets.data_utils import VideoRandomSquareCrop, VideoResizeSquare, mask_batch_text_tokens, select_batch_text_pivots\nfrom src.datasets.dataset_base import AlproBaseDataset, img_collate\n\nfrom src.datasets.randaugment import TemporalConsistentRandomAugment, RandomAugment\n\nfrom torch.utils.data import Dataset\n\nfrom torchvision import transforms\nfrom PIL import Image\nimport numpy as np\n\n\nclass AlproPretrainSparseDataset(AlproBaseDataset):\n    \"\"\"\n    datalist: list(tuples)  each tuple is (img_id, list(dicts)),\n        each dict {\n            \"type\": \"image\",\n            \"filepath\": \"/abs/path/to/COCO_val2014_000000401092.jpg\",\n            \"text\": \"A plate of food and a beverage are on a table.\",  # should be tokenized and digitized first?\n            ...\n            }\n    tokenizer:\n    max_img_size: int,\n    max_txt_len: int, max text sequence length, including special tokens.\n    vis_format: str, image or video, used to decide data loading method.\n    \"\"\"\n    def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type, txt_dir,\n                video_fmt='.mp4', crop_size=256, resize_size=288, fps=3, num_frm=3, frm_sampling_strategy=\"rand\",\n                max_img_size=1000, max_txt_len=20,\n                use_itm=True, is_train=True):\n        super(AlproPretrainSparseDataset, self).__init__(\n            datalist, tokenizer, img_lmdb_dir, \n            img_db_type=img_db_type,\n            fps=fps, \n            num_frm=num_frm, \n            frm_sampling_strategy=frm_sampling_strategy,\n            max_img_size=max_img_size, \n            max_txt_len=max_txt_len)\n        self.use_itm = use_itm\n\n        self.txt_dir = txt_dir\n        self.video_fmt = video_fmt\n\n        self.crop_size = crop_size\n        self.video_random_cropper = VideoRandomSquareCrop(crop_size)\n\n        self.resize_size = resize_size\n\n        self.is_train = is_train\n\n        if self.is_train:\n            self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])     \n        else:\n            self.randaug = None\n\n    def __len__(self):\n        return len(self.datalist)\n\n    def __getitem__(self, index):\n        start_time = None\n        end_time = None\n\n        # fetch video\n        num_retries = 10  # skip error videos\n\n        for _ in range(num_retries):\n            data_sample = self.datalist.iloc[index]\n\n            video_id = str(data_sample.video_id)\n            txt_len = int(data_sample.txt_len)\n\n            if hasattr(data_sample, 'text'):\n                text = data_sample.text.strip()\n            else:\n                raise NotImplementedError(\"Un-supported text annotation format.\")\n\n            # fetch video\n            video_path = os.path.join(self.img_db_dir, video_id + self.video_fmt) \n\n            # read with retries\n            for i in range(3):\n                img_array = self._load_video_from_path_decord(video_path, height=self.resize_size, width=self.resize_size)\n\n                if img_array is not None:\n                    break\n\n            if img_array is not None:\n                t, c, h, w = 
img_array.shape\n\n            # Select a random video if the current video was not able to access.\n            if img_array is None:\n                LOGGER.info(f\"Failed to load examples with video: {video_path}. \"\n                            f\"Will randomly sample an example as a replacement.\")\n                index = random.randint(0, len(self) - 1)\n                continue\n            else:\n                # square crop\n                img_array = self.video_random_cropper(img_array)\n\n                if self.randaug:\n                    img_array = self.randaug(img_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n\n                break\n        else:\n            raise RuntimeError(f\"Failed to fetch video after {num_retries} retries.\")\n        \n        examples = [{'text_str': text, 'itm_label': 1}]\n\n        return dict(\n            img=img_array,  # (T, C, H, W)\n            examples=examples,\n            n_examples=len(examples),  # used to create image feature copies.\n            type='video'\n        )\n\nclass PretrainImageTextDataset(Dataset):\n    def __init__(self, datalist, tokenizer, is_train=True, crop_size=256, resize_size=288, num_frm=4, max_txt_len=40):\n        self.datalist = datalist\n        self.max_txt_len = max_txt_len\n\n        self.crop_size = crop_size\n        self.resize_size = resize_size\n        self.num_frms = num_frm\n\n        self.is_train = is_train\n\n        self.transform = transforms.Compose([                        \n                transforms.RandomResizedCrop(self.crop_size, scale=(0.2, 1.0), interpolation=Image.BICUBIC),\n                transforms.RandomHorizontalFlip(),\n                RandomAugment(2,7,isPIL=True,augs=['Identity','Brightness','Sharpness',\n                                                'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'])     \n            ])    \n        \n    def __len__(self):\n        return len(self.datalist)\n\n    def __getitem__(self, index):\n        start_time = None\n        end_time = None\n\n        # fetch video\n        num_retries = 10  # skip error videos\n\n        for _ in range(num_retries):\n            data_sample = self.datalist[index]\n\n            try:\n                if type(data_sample['caption']) == list:\n                    text = random.choice(data_sample['caption'])\n                else:\n                    text = data_sample['caption']\n            \n                img_path = data_sample['image']\n                img_arr = Image.open(img_path).convert('RGB')   \n                img_arr = self.transform(img_arr)\n                img_arr = np.asarray(img_arr, dtype=np.float32).transpose(2, 0, 1)\n                img_arr = torch.from_numpy(img_arr).unsqueeze(0)\n                img_arr = img_arr.repeat(self.num_frms, 1, 1, 1)\n\n            except Exception as e:\n                img_arr = None\n\n            if img_arr is not None:\n                t, c, h, w = img_arr.shape\n\n            # Select a random video if the current video was not able to access.\n            if img_arr is None:\n                LOGGER.info(f\"Failed to load examples with image: {img_path}. 
\"\n                            f\"Will randomly sample an example as a replacement.\")\n                index = random.randint(0, len(self) - 1)\n                continue\n            else:\n                break\n        else:\n            raise RuntimeError(f\"Failed to fetch image after {num_retries} retries.\")\n        \n        examples = [{'text_str': text, 'itm_label': 1}]\n\n        return dict(\n            img=img_arr,  # (T, C, H, W)\n            examples=examples,\n            n_examples=len(examples),  # used to create image feature copies.\n            type='img'\n        )\n\n\nclass PretrainCollator(object):\n    \"\"\"is_train is kept here if we want to remove\n    the randomness during validation of MLM accuracy.\n    In that case, instantiate two PretrainCollator\"\"\"\n    def __init__(self, tokenizer, \n                 mlm=True, mlm_probability=0.15,\n                 patch_size=16,\n                 mpm=True,\n                 max_length=20, is_train=True):\n        self.tokenizer = tokenizer\n        self.mlm = mlm\n        self.mlm_probability = mlm_probability\n        self.max_length = max_length\n        self.is_train = is_train\n\n        self.mpm = mpm\n        self.patch_size = patch_size\n\n    def collate_batch(self, batch):\n        if isinstance(batch[0][\"img\"], torch.Tensor):\n            v_collate = default_collate\n        else:\n            v_collate = img_collate\n        visual_inputs = v_collate([d[\"img\"] for d in batch])  # (B, #frm=1 or T, 3, H, W)\n        # group data\n        text_examples = flat_list_of_lists([d[\"examples\"] for d in batch])\n        n_examples_list = [d[\"n_examples\"] for d in batch]  # (B, )\n        # group elements data\n        batch_enc = self.tokenizer.batch_encode_plus(\n            [d[\"text_str\"] for d in text_examples],\n            max_length=self.max_length,\n            padding='max_length',\n            return_tensors=\"pt\",\n            truncation=True\n        )\n        text_input_ids = batch_enc.input_ids  # (B, L)\n        text_input_ids_no_mask = text_input_ids.clone()\n\n        if self.mlm:\n            text_input_ids, mlm_labels = mask_batch_text_tokens(\n                text_input_ids, self.tokenizer,\n                is_train=self.is_train)  # make mlm data\n        else:\n            text_input_ids, mlm_labels = text_input_ids, None\n        \n        text_input_mask = batch_enc.attention_mask  # (B, L)\n        itm_labels = default_collate(\n            [d[\"itm_label\"] for d in text_examples])  # (B, )\n        \n        erase_elems = [random_erase(e, patch_size=self.patch_size) for e in visual_inputs.clone()]\n\n        if self.mpm:\n            crop_visual_inputs = v_collate([elems[0] for elems in erase_elems])\n            mpm_masks = v_collate([elems[1] for elems in erase_elems])\n            context_visual_inputs = v_collate([elems[2] for elems in erase_elems])\n\n            return dict(\n                visual_inputs=visual_inputs,  # (B, #frm=1 or T, H, W, C)\n                crop_visual_inputs=crop_visual_inputs,\n                context_visual_inputs=context_visual_inputs,\n                mpm_mask=mpm_masks,\n                text_input_ids=text_input_ids_no_mask,\n                mlm_text_input_ids=text_input_ids,\n                mlm_labels=mlm_labels,\n                text_input_mask=text_input_mask, # used to exclude [PAD] token\n                itm_labels=itm_labels,\n                n_examples_list=n_examples_list,  # used to create image feature copies.\n                
type=batch[0]['type']\n            )\n        else:\n            return dict(\n                visual_inputs=visual_inputs,  # (B, #frm=1 or T, H, W, C)\n                text_input_ids=text_input_ids_no_mask,\n                mlm_text_input_ids=text_input_ids,\n                mlm_labels=mlm_labels,\n                text_input_mask=text_input_mask, # used to exclude [PAD] token\n                itm_labels=itm_labels,\n                n_examples_list=n_examples_list,  # used to create image feature copies.\n                type=batch[0]['type']\n            )\n\ndef random_erase(input_img, patch_size, s_l=0.3, s_h=0.5, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255):\n    assert input_img.ndim == 4\n    img_t, img_c, img_h, img_w = input_img.shape\n\n    while True:\n        s = np.random.uniform(s_l, s_h) * img_h * img_w\n        r = np.random.uniform(r_1, r_2)\n        w = int(np.sqrt(s / r))\n        h = int(np.sqrt(s * r))\n        left = np.random.randint(0, img_w)\n        top = np.random.randint(0, img_h)\n\n        w = w - w % patch_size\n        h = h - h % patch_size\n\n        left = left - left % patch_size\n        top = top - top % patch_size\n\n        if left + w <= img_w and top + h <= img_h:\n            break\n\n    context_img = input_img.clone()\n    context_img[:, :, top: top + h, left: left + w] = 0\n\n    input_img = input_img[:, :, top: top + h, left: left + w]\n    pad = (left, img_w - left - w, top, img_h - top - h)\n    input_img = torch.nn.functional.pad(input_img, pad=pad, mode='constant', value=0.0)\n\n    img_masks = torch.ones_like(input_img)\n    img_masks[:, :, top: top+h, left: left+w] = 0\n\n    img_masks = torch.nn.functional.avg_pool2d(img_masks.float(), kernel_size=(patch_size, patch_size), stride=patch_size)\n    img_masks = torch.mean(img_masks, dim=(0, 1))\n\n    return input_img, img_masks, context_img"
  },
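  {
    "path": "examples/demo_random_erase.py",
    "content": "\"\"\"Editor-added shape check -- NOT part of the original codebase; the file path and\ntoy clip are hypothetical (importing this module also requires the repo's spacy / cv2 /\ndecord dependencies).\n\nrandom_erase (src.datasets.dataset_pretrain_sparse) cuts a patch-aligned rectangle out\nof a (T, C, H, W) clip and returns\n  1) the clip with only that rectangle kept in place (everything else zero-padded),\n  2) a (H/patch, W/patch) mask that is 0 on the kept patches and 1 elsewhere,\n  3) the complementary 'context' clip with the rectangle zeroed out.\n\"\"\"\nimport numpy as np\nimport torch\n\nfrom src.datasets.dataset_pretrain_sparse import random_erase\n\nnp.random.seed(0)  # random_erase draws its rectangle with numpy\n\nclip = torch.rand(4, 3, 224, 224)  # (T, C, H, W), e.g. num_frm=4, crop_size=224\ncrop, mpm_mask, context = random_erase(clip, patch_size=16)\n\nprint(crop.shape)         # torch.Size([4, 3, 224, 224])\nprint(context.shape)      # torch.Size([4, 3, 224, 224])\nprint(mpm_mask.shape)     # torch.Size([14, 14]) = (224/16, 224/16)\nprint(mpm_mask.unique())  # tensor([0., 1.]) -- 0 marks the kept (cropped) patches\n"
  },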
  {
    "path": "src/datasets/dataset_video_qa.py",
    "content": "import os\nimport torch\nimport random\nimport numpy as np\nimport copy\nfrom torch.utils.data.dataloader import default_collate\nfrom src.utils.basic_utils import flat_list_of_lists\nfrom src.utils.load_save import LOGGER\nfrom src.datasets.dataset_base import AlproBaseDataset\nfrom src.datasets.randaugment import TemporalConsistentRandomAugment\n\n\nclass AlproVideoQADataset(AlproBaseDataset):\n    \"\"\" This should work for both train and test (where labels are not available).\n    task_type: str, one of [action, frameqa, transition]\n        where action and transition are multiple-choice QA,\n            frameqa is opened QA similar to VQA.\n    datalist: list(tuples)  each tuple is (img_id, list(dicts)),\n        each dict\n    tokenizer:\n    max_img_size: int,\n    max_txt_len: int, max text sequence length, including special tokens.\n    return_label: bool, whether return label in __getitem__\n    random_sample_clips:\n    \"\"\"\n    open_ended_qa_names = [\"frameqa\", \"msrvtt_qa\", \"msvd_qa\"]\n\n    def __init__(self, task_type, datalist, tokenizer, img_lmdb_dir,\n                 fps=3, num_frm=3, frm_sampling_strategy=\"rand\",\n                 max_img_size=1000, max_txt_len=20, ans2label=None,\n                 ensemble_n_clips=1, return_label=True, is_train=False, random_sample_clips=True, \n                 video_fmt='.mp4', img_db_type='lmdb'):\n        super(AlproVideoQADataset, self).__init__(\n            datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type,\n            fps=fps, num_frm=num_frm,\n            frm_sampling_strategy=frm_sampling_strategy,\n            max_img_size=max_img_size, max_txt_len=max_txt_len)\n        self.ensemble_n_clips = ensemble_n_clips\n        self.return_label = return_label\n        self.is_train = is_train\n        self.task_type = task_type\n        self.ans2label = ans2label\n        self.num_labels = len(ans2label)\n        self.random_sample_clips = random_sample_clips\n        self.label2ans = {v: k for k, v in ans2label.items()}\n        self.qid2data = {d[\"question_id\"]: d for group in datalist for d in group[1]}\n\n        self.video_fmt = video_fmt\n\n        if self.is_train:\n            self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])     \n        else:\n            self.randaug = None\n\n    def __len__(self):\n        return len(self.datalist)\n\n\n    def __getitem__(self, index):\n        # skip error videos:\n        num_retries = 5\n        for _ in range(num_retries):\n            vid_id, examples = self.datalist[index]  # one video with multiple examples\n            if self.ensemble_n_clips > 1:\n                raise NotImplementedError('Do not support multiple clips for now.')\n            else:\n                video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) \n                vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)\n\n            # Select a random video if the current video was not able to access.\n            if vid_frm_array is None:\n                LOGGER.info(f\"Failed to load examples with video: {vid_id}. 
\"\n                            f\"Will randomly sample an example as a replacement.\")\n                index = random.randint(0, len(self) - 1)\n                continue\n\n            if self.randaug:\n                vid_frm_array = self.randaug(vid_frm_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n\n            examples = [self._get_single_example(e) for e in examples]\n            return dict(\n                vid=vid_frm_array,\n                examples=examples,\n                n_examples=len(examples)  # used to create image feature copies.\n            )\n        else:\n            raise RuntimeError(f\"Failed to fetch video after {num_retries} retries.\")\n\n    def _get_single_example(self, data):\n        example = dict(\n            q_str=data[\"question\"],\n            question_id=data[\"question_id\"],\n            label=data[\"answer\"]\n        )\n        if self.task_type in self.open_ended_qa_names:\n            if self.return_label:\n                example[\"label\"] = self.ans2label[example[\"label\"]]\n        if not self.return_label:\n            example[\"label\"] = None\n        return example\n\n    def evaluate_qa(self, results):\n        \"\"\"\n        Args:\n            results: list(dict),\n              each dict is\n                {\n                    \"question_id\": int,\n                    \"answer\": int or float, either answer_idx (int)\n                }\n        Returns:\n            TGIF-QA score\n        \"\"\"\n        preds = []\n        gts = []\n        # for frameQA\n        answer_types = []\n        answer_type2idx = dict(\n            frameqa={\"object\": 0, \"number\": 1, \"color\": 2, \"location\": 3},\n            msrvtt_qa={k: idx for idx, k in enumerate([\"what\", \"who\", \"how\", \"where\", \"when\"])},\n            msvd_qa={k: idx for idx, k in enumerate([\"what\", \"who\", \"how\", \"where\", \"when\"])}\n        )\n\n        qid2pred_ans = {r[\"question_id\"]: r[\"answer\"] for r in results}\n        if self.task_type in self.open_ended_qa_names:  # convert ans_idx, int --> str\n            qid2pred_ans = {k: self.label2ans[v] for k, v in qid2pred_ans.items()}\n\n        for qid, pred_ans in qid2pred_ans.items():\n            preds.append(pred_ans)\n\n            gt_data = self.qid2data[qid]\n            gt_ans = gt_data[\"answer\"]\n            if self.task_type in self.open_ended_qa_names:\n                answer_types.append(answer_type2idx[self.task_type][gt_data[\"answer_type\"]])\n            gts.append(gt_ans)\n\n        preds = np.array(preds)\n        gts = np.array(gts)\n        metrics = dict()\n        # preds and gts are array of strings\n        metrics[\"overall_acc\"] = float(np.mean(preds == gts))\n        if self.task_type in self.open_ended_qa_names:\n            answer_types = np.array(answer_types)\n            ratios = dict()\n            for ans_type, ans_type_idx in answer_type2idx[self.task_type].items():\n                answer_type_mask = answer_types == ans_type_idx\n                answer_type_corrects = (\n                        preds[answer_type_mask] == gts[answer_type_mask])\n                metrics[f\"{ans_type}_acc\"] = float(\n                    np.mean(answer_type_corrects)) if len(answer_type_corrects) != 0 else 0\n                ratios[f\"{ans_type}_ratio\"] = [\n                    1. 
* len(answer_type_corrects) / len(answer_types),\n                    len(answer_type_corrects)]\n            metrics[\"ratios\"] = ratios\n        return metrics\n\n\nclass VideoQACollator(object):\n    def __init__(self, tokenizer, max_length=20, task_type=\"action\", n_options=5):\n        self.tokenizer = tokenizer\n        self.max_length = max_length\n        self.task_type = task_type\n        self.n_options = n_options\n\n    def collate_batch(self, batch):\n        v_collate = default_collate\n        visual_inputs = v_collate([d[\"vid\"] for d in batch])  # (B, T, 3, H, W)\n        # group data\n        text_examples = flat_list_of_lists([d[\"examples\"] for d in batch])\n        n_examples_list = [d[\"n_examples\"] for d in batch]  # (B, )\n        # group elements data\n        # directly concatenate question and option as a single seq.\n        if self.task_type in [\"action\", \"transition\"]:\n            text_str_list = flat_list_of_lists(\n                [[d[\"q_str\"] + \" \" + d[\"options_str_list\"][i] for i in range(self.n_options)]\n                 for d in text_examples]\n            )  # (B * n_options, )\n        else:\n            text_str_list = [d[\"q_str\"] for d in text_examples]  # (B, )\n        batch_enc = self.tokenizer.batch_encode_plus(\n            text_str_list,\n            max_length=self.max_length,\n            padding='max_length',\n            return_tensors=\"pt\",\n            truncation=True\n        )\n        text_input_ids = batch_enc.input_ids  # (B, L)\n        text_input_mask = batch_enc.attention_mask  # (B, L)\n\n        labels = default_collate([int(d[\"label\"]) for d in text_examples]) \\\n            if text_examples[0][\"label\"] is not None else None  # (B, #ans)\n        question_ids = [d[\"question_id\"] for d in text_examples]\n        return dict(\n            visual_inputs=visual_inputs,  # (B, #frm, H, W, C)\n            text_input_ids=text_input_ids,\n            text_input_mask=text_input_mask,\n            question_ids=question_ids,\n            labels=labels,\n            n_examples_list=n_examples_list  # used to create image feature copies.\n        )\n"
  },
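  {
    "path": "examples/demo_video_qa_collator.py",
    "content": "\"\"\"Editor-added sketch -- NOT part of the original codebase; the file path, the\ntokenizer checkpoint and the toy QA examples are all hypothetical assumptions.\n\nVideoQACollator (src.datasets.dataset_video_qa) turns a list of per-video samples into\nmodel-ready tensors.  For open-ended tasks (e.g. 'msrvtt_qa') each question is encoded\non its own; for 'action' / 'transition' each question is concatenated with its\nn_options candidate answers before encoding.\n\"\"\"\nimport torch\nfrom transformers import BertTokenizerFast\n\nfrom src.datasets.dataset_video_qa import VideoQACollator\n\ntokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')  # assumed checkpoint\ncollator = VideoQACollator(tokenizer, max_length=20, task_type='msrvtt_qa')\n\n# two videos with one open-ended QA example each; 'vid' mimics (T, C, H, W) frames\nbatch = [\n    dict(vid=torch.rand(4, 3, 224, 224),\n         examples=[dict(q_str='what is the man doing', question_id=0, label=3)],\n         n_examples=1),\n    dict(vid=torch.rand(4, 3, 224, 224),\n         examples=[dict(q_str='who is singing', question_id=1, label=7)],\n         n_examples=1),\n]\n\nout = collator.collate_batch(batch)\nprint(out['visual_inputs'].shape)   # torch.Size([2, 4, 3, 224, 224])\nprint(out['text_input_ids'].shape)  # torch.Size([2, 20])\nprint(out['labels'], out['question_ids'])  # tensor([3, 7]) [0, 1]\n"
  },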
  {
    "path": "src/datasets/dataset_video_retrieval.py",
    "content": "import random\nimport copy\nimport os\nimport torch\nimport numpy as np\nfrom torch.utils.data.dataloader import default_collate\nfrom src.utils.basic_utils import flat_list_of_lists\nfrom src.utils.load_save import LOGGER\nfrom src.datasets.dataset_base import AlproBaseDataset\nfrom src.datasets.randaugment import TemporalConsistentRandomAugment\n\n\nclass AlproVideoRetrievalDataset(AlproBaseDataset):\n    \"\"\" This should work for both train and test (where labels are not available).\n    datalist: list(tuples)  each tuple is (img_id, list(dicts)),\n        each dict\n    tokenizer:\n    max_img_size: int,\n    max_txt_len: int, max text sequence length, including special tokens.\n    random_sample_clips: bool, whether using randomly sampled N clips or always use uniformly sampled N clips\n    \"\"\"\n    def __init__(self, datalist, tokenizer, img_lmdb_dir,\n                 fps=3, num_frm=3, frm_sampling_strategy=\"rand\",\n                 max_img_size=1000, max_txt_len=40, itm_neg_size=1,\n                 ensemble_n_clips=1, random_sample_clips=True,\n                 video_fmt='.mp4', img_db_type='lmdb', is_train=False):\n        super(AlproVideoRetrievalDataset, self).__init__(\n            datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type,\n            fps=fps, num_frm=num_frm,\n            frm_sampling_strategy=frm_sampling_strategy,\n            max_img_size=max_img_size, max_txt_len=max_txt_len)\n        self.ensemble_n_clips = ensemble_n_clips\n        self.num_labels = 2\n        self.itm_neg_size = itm_neg_size\n        self.random_sample_clips = random_sample_clips\n        self.id2data = {\n            d[\"id\"]: d for group in datalist for d in group[1]}\n\n        self.is_train = is_train\n        self.video_fmt = video_fmt\n\n        if self.is_train:\n            self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])     \n        else:\n            self.randaug = None\n\n    def __len__(self):\n        return len(self.datalist)\n\n    def __getitem__(self, index):\n        # skip error videos:\n        num_retries = 5\n        for _ in range(num_retries):\n            vid_id, examples = self.datalist[index]  # one video with multiple examples\n            if self.ensemble_n_clips > 1:\n                raise NotImplementedError('Do not support multiple clips for now.')\n            else:\n                video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) \n                vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)\n\n            # Select a random video if the current video was not able to access.\n            if vid_frm_array is None:\n                LOGGER.info(f\"Failed to load examples with video: {vid_id}. 
\"\n                            f\"Will randomly sample an example as a replacement.\")\n                index = random.randint(0, len(self) - 1)\n                continue\n            sampled_examples = []\n            for e in examples:\n                s = self._get_single_example(e, index)\n                if isinstance(s, dict):\n                    sampled_examples.append(s)\n                else:\n                    sampled_examples.extend(s)\n            return dict(\n                vid=vid_frm_array,\n                examples=sampled_examples,\n                n_examples=len(sampled_examples)  # used to create image feature copies.\n            )\n        else:\n            raise RuntimeError(\n             f\"Failed to fetch video after {num_retries} retries.\")\n\n    def _get_single_example(self, data, index):\n        examples = []\n\n        text_str = data[\"txt\"]\n        itm_label = 1  # positive pair\n        examples.append(dict(\n            text_str=text_str,\n            itm_label=itm_label\n        ))\n        return examples\n\n\nclass VideoRetrievalCollator(object):\n    def __init__(self, tokenizer, max_length=40):\n        self.tokenizer = tokenizer\n        self.max_length = max_length\n\n    def collate_batch(self, batch):\n        # FIXME there is a chance that two captions associated with the same video are batched together. Might need to fix.\n        v_collate = default_collate\n        visual_inputs = v_collate([d[\"vid\"] for d in batch])  # (B, T, 3, H, W)\n        # group data\n        text_examples = flat_list_of_lists([d[\"examples\"] for d in batch])\n        n_examples_list = [d[\"n_examples\"] for d in batch]  # (B, )\n        # group elements data\n        # directly concatenate question and option as a single seq.\n        text_str_list = [d[\"text_str\"] for d in text_examples]  # (B, )\n        batch_enc = self.tokenizer.batch_encode_plus(\n            text_str_list,\n            max_length=self.max_length,\n            padding='max_length',\n            return_tensors=\"pt\",\n            truncation=True\n        )\n        text_input_ids = batch_enc.input_ids  # (B, L)\n        text_input_mask = batch_enc.attention_mask  # (B, L)\n\n        if \"itm_label\" in text_examples[0]:\n            itm_labels = default_collate(\n                [d[\"itm_label\"] for d in text_examples])  # (B, )\n        else:\n            itm_labels = None\n\n        if \"id\" in text_examples[0]:\n            caption_ids = [d[\"id\"] for d in text_examples]  # (B, )\n        else:\n            caption_ids = None\n        collated_batch = dict(\n            visual_inputs=visual_inputs,  # (B, #frm, H, W, C)\n            text_input_ids=text_input_ids,\n            text_input_mask=text_input_mask,\n            caption_ids=caption_ids,  # list(int), example ids,\n            labels=itm_labels,\n            n_examples_list=n_examples_list  # used to create image feature copies.\n        )\n        if \"vid_id\" in batch[0] and len(batch) == 1:\n            collated_batch[\"vid_id\"] = batch[0][\"vid_id\"]\n        return collated_batch\n\n\nclass AlproVideoRetrievalEvalDataset(AlproBaseDataset):\n    \"\"\" Sample by video/image, calculate scores between each video with all the text\n    and loop through all the videos. 
Each batch will only contain a single video,\n    but multiple text.\n\n    datalist: list(dict), each dict\n    tokenizer:\n    max_img_size: int,\n    max_txt_len: int, max text sequence length, including special tokens.\n    \"\"\"\n    def __init__(self, datalist, tokenizer, img_lmdb_dir,\n                 fps=3, num_frm=3, frm_sampling_strategy=\"rand\",\n                 max_img_size=1000, max_txt_len=40, ensemble_n_clips=1,\n                 video_fmt='.mp4', img_db_type='lmdb'):\n        self.ensemble_n_clips = ensemble_n_clips\n        super(AlproVideoRetrievalEvalDataset, self).__init__(\n            datalist, tokenizer, img_lmdb_dir,\n            fps=fps, num_frm=num_frm,\n            frm_sampling_strategy=frm_sampling_strategy,\n            max_img_size=max_img_size, max_txt_len=max_txt_len,\n            img_db_type=img_db_type)\n        # id is unique id per caption/example\n        for i, d in enumerate(self.datalist):\n            assert i == d[\"id\"]\n        self.gt_cap_id2vid_id = {d[\"id\"]: d[\"vid_id\"] for d in datalist}\n        self.cap_id2data = {d[\"id\"]: d for d in datalist}\n        self.batches, self.text_batch = self._prepare_batches_by_video()\n        self.id2data = {d[\"id\"]: d for d in self.datalist}\n\n        self.video_fmt = video_fmt\n\n    def __len__(self):\n        return len(self.batches)\n\n    def __getitem__(self, index):\n        # skip error videos:\n        batch = dict()\n\n        batch[\"vid_id\"] = self.batches[index][\"vid_id\"]  # one video with multiple examples\n        batch[\"examples\"] = self.text_batch[\"examples\"]\n        batch[\"n_examples\"] = self.text_batch[\"n_examples\"]\n        batch[\"ids\"] = self.text_batch[\"ids\"]\n\n        if self.ensemble_n_clips > 1:\n            raise NotImplementedError('Do not support multiple clips for now.')\n        else:\n            # if self.is_train and self.random_sample_clips:\n            vid_id = batch[\"vid_id\"]\n\n            video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt) \n            vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)\n\n        batch[\"vid\"] = vid_frm_array\n        return batch\n\n    def _prepare_batches_by_video(self):\n        \"\"\"create batches where each batch contains a single video with multiple text\"\"\"\n        text_list = []\n        for d in self.datalist:\n            text_list.append(dict(\n                text_str=d[\"txt\"],\n                id=d[\"id\"],\n            ))\n        text_batch = dict(\n            vid_id=None,\n            examples=text_list,\n            n_examples=len(text_list),\n            ids=[d[\"id\"] for d in text_list]\n        )\n\n        # make 1000 batches for 1000video x 1000text combinations.\n        # each batch contains 1video x 1000text\n        batches = []\n        for idx, d in enumerate(self.datalist):\n             #_batch = copy.deepcopy(text_batch)\n            _batch = dict()\n            _batch[\"vid_id\"] = d[\"vid_id\"]\n            batches.append(_batch)\n        return batches, text_batch\n"
  },
  {
    "path": "src/datasets/randaugment.py",
    "content": "import cv2\nimport numpy as np\nimport torch\n\n\n## aug functions\ndef identity_func(img):\n    return img\n\n\ndef autocontrast_func(img, cutoff=0):\n    '''\n        same output as PIL.ImageOps.autocontrast\n    '''\n    n_bins = 256\n\n    def tune_channel(ch):\n        n = ch.size\n        cut = cutoff * n // 100\n        if cut == 0:\n            high, low = ch.max(), ch.min()\n        else:\n            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n            low = np.argwhere(np.cumsum(hist) > cut)\n            low = 0 if low.shape[0] == 0 else low[0]\n            high = np.argwhere(np.cumsum(hist[::-1]) > cut)\n            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]\n        if high <= low:\n            table = np.arange(n_bins)\n        else:\n            scale = (n_bins - 1) / (high - low)\n            offset = -low * scale\n            table = np.arange(n_bins) * scale + offset\n            table[table < 0] = 0\n            table[table > n_bins - 1] = n_bins - 1\n        table = table.clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef equalize_func(img):\n    '''\n        same output as PIL.ImageOps.equalize\n        PIL's implementation is different from cv2.equalize\n    '''\n    n_bins = 256\n\n    def tune_channel(ch):\n        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])\n        non_zero_hist = hist[hist != 0].reshape(-1)\n        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)\n        if step == 0: return ch\n        n = np.empty_like(hist)\n        n[0] = step // 2\n        n[1:] = hist[:-1]\n        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)\n        return table[ch]\n\n    channels = [tune_channel(ch) for ch in cv2.split(img)]\n    out = cv2.merge(channels)\n    return out\n\n\ndef rotate_func(img, degree, fill=(0, 0, 0)):\n    '''\n    like PIL, rotate by degree, not radians\n    '''\n    H, W = img.shape[0], img.shape[1]\n    center = W / 2, H / 2\n    M = cv2.getRotationMatrix2D(center, degree, 1)\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)\n    return out\n\n\ndef horizontal_flip_func(img):\n    '''\n    [dxli]\n    horizontally flip an image.\n    '''\n    out = cv2.flip(img, 1)\n\n    return out\n\n\ndef solarize_func(img, thresh=128):\n    '''\n        same output as PIL.ImageOps.posterize\n    '''\n    table = np.array([el if el < thresh else 255 - el for el in range(256)])\n    table = table.clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef color_func(img, factor):\n    '''\n        same output as PIL.ImageEnhance.Color\n    '''\n    ## implementation according to PIL definition, quite slow\n    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]\n    #  out = blend(degenerate, img, factor)\n    #  M = (\n    #      np.eye(3) * factor\n    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor)\n    #  )[np.newaxis, np.newaxis, :]\n    M = (\n            np.float32([\n                [0.886, -0.114, -0.114],\n                [-0.587, 0.413, -0.587],\n                [-0.299, -0.299, 0.701]]) * factor\n            + np.float32([[0.114], [0.587], [0.299]])\n    )\n    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)\n    return out\n\n\ndef contrast_func(img, factor):\n    \"\"\"\n        same output as PIL.ImageEnhance.Contrast\n    \"\"\"\n    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))\n    table = np.array([(\n        el - mean) * factor + mean\n        for el in range(256)\n    ]).clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef brightness_func(img, factor):\n    '''\n        same output as PIL.ImageEnhance.Brightness\n    '''\n    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)\n    out = table[img]\n    return out\n\n\ndef sharpness_func(img, factor):\n    '''\n    The differences between this result and PIL's are all on the 4 boundaries; the center\n    areas are the same\n    '''\n    kernel = np.ones((3, 3), dtype=np.float32)\n    kernel[1][1] = 5\n    kernel /= 13\n    degenerate = cv2.filter2D(img, -1, kernel)\n    if factor == 0.0:\n        out = degenerate\n    elif factor == 1.0:\n        out = img\n    else:\n        out = img.astype(np.float32)\n        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]\n        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)\n        out = out.astype(np.uint8)\n    return out\n\n\ndef shear_x_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, factor, 0], [0, 1, 0]])\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)\n    return out\n\n\ndef translate_x_func(img, offset, fill=(0, 0, 0)):\n    '''\n        same output as PIL.Image.transform\n    '''\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, -offset], [0, 1, 0]])\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)\n    return out\n\n\ndef translate_y_func(img, offset, fill=(0, 0, 0)):\n    '''\n        same output as PIL.Image.transform\n    '''\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [0, 1, -offset]])\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)\n    return out\n\n\ndef posterize_func(img, bits):\n    '''\n        same output as PIL.ImageOps.posterize\n    '''\n    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))\n    return out\n\n\ndef shear_y_func(img, factor, fill=(0, 0, 0)):\n    H, W = img.shape[0], img.shape[1]\n    M = np.float32([[1, 0, 0], [factor, 1, 0]])\n    out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)\n    return out\n\n\n# def cutout_func(img, pad_size, replace=(0, 0, 0)):\n#     replace = np.array(replace, dtype=np.uint8)\n#     H, W = img.shape[0], img.shape[1]\n#     rh, rw = np.random.random(2)\n#     pad_size = pad_size // 2\n#     ch, cw = int(rh * H), int(rw * W)\n#     x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)\n#     y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)\n#     out = img.copy()\n#     out[x1:x2, y1:y2, :] = replace\n#     return out\n\n\n### level to args\ndef enhance_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        return ((level / MAX_LEVEL) * 1.8 + 0.1,)\n    return 
level_to_args\n\n\ndef shear_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 0.3\n        # if np.random.random() > 0.5: level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef translate_level_to_args(translate_const, MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * float(translate_const)\n        # if np.random.random() > 0.5: level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * cutout_const)\n        return (level, replace_value)\n\n    return level_to_args\n\n\ndef solarize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 256)\n        return (level, )\n    return level_to_args\n\n\ndef none_level_to_args(level):\n    return ()\n\n\ndef posterize_level_to_args(MAX_LEVEL):\n    def level_to_args(level):\n        level = int((level / MAX_LEVEL) * 4)\n        return (level, )\n    return level_to_args\n\n\ndef rotate_level_to_args(MAX_LEVEL, replace_value):\n    def level_to_args(level):\n        level = (level / MAX_LEVEL) * 30\n        # if np.random.random() < 0.5:\n        #     level = -level\n        return (level, replace_value)\n\n    return level_to_args\n\n\nfunc_dict = {\n    'Identity': identity_func,\n    # 'AutoContrast': autocontrast_func,\n    'Equalize': equalize_func,\n    'Rotate': rotate_func,\n    'Solarize': solarize_func,\n    'Color': color_func,\n    'Contrast': contrast_func,\n    'Brightness': brightness_func,\n    'Sharpness': sharpness_func,\n    'ShearX': shear_x_func,\n    'TranslateX': translate_x_func,\n    'TranslateY': translate_y_func,\n    'Posterize': posterize_func,\n    'ShearY': shear_y_func,\n    'HorizontalFlip': horizontal_flip_func # [dxli]\n}\n\ntranslate_const = 10\nMAX_LEVEL = 10\nreplace_value = (128, 128, 128)\narg_dict = {\n    'Identity': none_level_to_args,\n    # 'AutoContrast': none_level_to_args,\n    'Equalize': none_level_to_args,\n    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),\n    'Solarize': solarize_level_to_args(MAX_LEVEL),\n    'Color': enhance_level_to_args(MAX_LEVEL),\n    'Contrast': enhance_level_to_args(MAX_LEVEL),\n    'Brightness': enhance_level_to_args(MAX_LEVEL),\n    'Sharpness': enhance_level_to_args(MAX_LEVEL),\n    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),\n    'TranslateX': translate_level_to_args(\n        translate_const, MAX_LEVEL, replace_value\n    ),\n    'TranslateY': translate_level_to_args(\n        translate_const, MAX_LEVEL, replace_value\n    ),\n    'Posterize': posterize_level_to_args(MAX_LEVEL),\n    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),\n    'HorizontalFlip': none_level_to_args  # [dxli]\n}\n\n\nclass TemporalConsistentRandomAugment(object):\n\n    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):\n        self.N = N\n        self.M = M\n        self.p = p\n        self.tensor_in_tensor_out = tensor_in_tensor_out\n        if augs:\n            self.augs = augs       \n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N, replace=False)\n        # return [(op, 0.5, self.M) for op in sampled_ops]\n        return [(op, self.M) for op in sampled_ops]\n\n    def 
__call__(self, frames):\n        assert frames.shape[-1] == 3, 'Expecting the last dimension to be 3 RGB channels, i.e. (b, h, w, c).'\n        \n        if self.tensor_in_tensor_out:\n            frames = frames.numpy().astype(np.uint8)\n        \n        num_frames = frames.shape[0]\n\n        ops = num_frames * [self.get_random_ops()]\n        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]\n\n        frames = torch.stack(list(map(self._aug, frames, ops, apply_or_not)), dim=0).float()\n\n        return frames\n\n    def _aug(self, img, ops, apply_or_not):\n        for i, (name, level) in enumerate(ops):\n            if not apply_or_not[i]:\n                continue\n            args = arg_dict[name](level)\n            img = func_dict[name](img, *args) \n        return torch.from_numpy(img)\n\nclass RandomAugment(object):\n\n    def __init__(self, N=2, M=10, isPIL=False, augs=[]):\n        self.N = N\n        self.M = M\n        self.isPIL = isPIL\n        if augs:\n            self.augs = augs       \n        else:\n            self.augs = list(arg_dict.keys())\n\n    def get_random_ops(self):\n        sampled_ops = np.random.choice(self.augs, self.N)\n        return [(op, 0.5, self.M) for op in sampled_ops]\n\n    def __call__(self, img):\n        if self.isPIL:\n            img = np.array(img)            \n        ops = self.get_random_ops()\n        for name, prob, level in ops:\n            if np.random.random() > prob:\n                continue\n            args = arg_dict[name](level)\n            img = func_dict[name](img, *args) \n        return img\n\n\ndef save_frames_grid(img_array, out_path):\n    import torch\n    from torchvision.utils import make_grid\n    from PIL import Image\n\n    if len(img_array.shape) == 3:\n        img_array = img_array.unsqueeze(0)\n    elif len(img_array.shape) == 5:\n        b, t, c, h, w = img_array.shape\n        img_array = img_array.view(-1, c, h, w)\n    elif len(img_array.shape) == 4:\n        pass\n    else:\n        raise NotImplementedError('Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored.')\n    \n    assert img_array.shape[1] == 3, \"Expecting input shape of (H, W, 3), i.e. 
RGB-only.\"\n    \n    grid = make_grid(img_array)\n    ndarr = grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()\n\n    img = Image.fromarray(ndarr)\n\n    img.save(out_path)\n\n\ndef stack(data, dim=0):\n    shape = data[0].shape  # need to handle empty list\n    shape = shape[:dim] + (len(data),) + shape[dim:]\n    x = torch.cat(data, dim=dim)\n    x = x.reshape(shape)\n    # need to handle case where dim=-1\n    # which is not handled here yet\n    # but can be done with transposition\n    return x\n\n\nif __name__ == '__main__':\n    import decord, os\n    from decord import VideoReader\n    decord.bridge.set_bridge('torch')\n\n    root_dir = '/export/share/dongxuli/data/webvid2m/postprocess/downsampled_videos'\n    video_id = '1058234725.mp4'\n\n    video_path = os.path.join(root_dir, video_id) \n    vr = VideoReader(video_path)\n\n    frames = vr.get_batch([1, 3, 5, 7, 9])\n    frames = frames\n\n    # a = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast', 'Equalize','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'])     \n    a = TemporalConsistentRandomAugment(N=1, M=5, augs=['HorizontalFlip'])\n\n    print(frames[0].shape)\n    save_frames_grid(frames.permute(0, 3, 1, 2), 'before.jpg')\n\n    after_frames = a(frames)\n    print(after_frames.shape)\n\n    save_frames_grid(after_frames.permute(0, 3, 1, 2), 'after.jpg')"
  },
  {
    "path": "src/modeling/alpro_models.py",
    "content": "import copy\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm\nfrom einops import rearrange, reduce, repeat\nfrom horovod import torch as hvd\nfrom src.modeling.timesformer.vit import TimeSformer\nfrom src.modeling.xbert import (BertEmbeddings, BertEncoder, BertForMaskedLM,\n                                BertLMPredictionHead, BertModel, BertPooler,\n                                BertPreTrainedModel, BertPreTrainingHeads)\nfrom src.utils.basic_utils import load_json, load_jsonl, save_frames_grid\nfrom src.utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss, MSELoss\n\n\nclass AlproBaseModel(nn.Module):\n    def __init__(self, config=None, input_format='RGB', video_enc_cfg=None, temp=0.07):\n        super().__init__()\n        \n        self.temp = nn.Parameter(torch.ones([]) * temp)   \n\n        self.bert_config = config\n\n        visual_model_cls = eval(video_enc_cfg['cls'])\n\n        self.visual_encoder = visual_model_cls(model_cfg=video_enc_cfg, input_format=input_format, cross_attention_config=config)\n        self.text_encoder = BertForMaskedLM.from_pretrained('bert-base-uncased', config=self.bert_config)\n\n        # FIXME make them configurable\n        embed_dim = 256\n        vision_width = 768\n\n        text_width = self.bert_config.hidden_size\n\n        self.vision_proj = nn.Linear(vision_width, embed_dim)\n        self.text_proj = nn.Linear(text_width, embed_dim)         \n\n        self.itc_token_type = self.bert_config.itc_token_type\n        self.itm_head = nn.Linear(text_width, 2)     \n\n\n    def load_separate_ckpt(self, visual_weights_path=None, bert_weights_path=None):\n        if visual_weights_path:\n            self.visual_encoder.load_state_dict(visual_weights_path)\n\n        # if bert_weights_path:\n        #     load_multimodal_encoder_state_dict_with_mismatch(self.cross_encoder, bert_weights_path)\n        #     load_mlm_head_state_dict_with_mismatch(self.mlm_head, bert_weights_path)\n\n    # def freeze_cnn_backbone(self):\n    #     for n, p in self.visual_encoder.feature.named_parameters():\n    #         p.requires_grad = False\n\n\nclass AlproForPretrain(AlproBaseModel):\n    def __init__(self, config, video_enc_cfg, input_format='RGB'):\n        super(AlproForPretrain, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)\n\n        # model for generating pseudo labels\n        self.prompter = Prompter(config, video_enc_cfg)\n\n        self.use_mask_prob = 0\n        self.mpm_head = nn.Sequential(\n            nn.Linear(config.hidden_size,\n                    config.hidden_size * 2),\n            nn.ReLU(True),\n            nn.Linear(config.hidden_size * 2, self.prompter.entity_num)\n        )\n\n    def build_text_prompts(self, prompts):\n        self.prompter.build_text_prompts(prompts)\n\n    def get_pseudo_labels(self, batch):\n        return self.prompter.get_pseudo_labels(batch)\n\n    def forward(self, batch):\n        with torch.no_grad():\n            self.temp.clamp_(0.001,0.5)\n\n        visual_inputs = batch['visual_inputs']\n\n        use_mpm = 'mpm_mask' in batch\n        if use_mpm:\n            context_visual_inputs = batch['context_visual_inputs']\n\n        device = visual_inputs.device\n        b, t, c, h, w = visual_inputs.shape\n\n        # forward image and text features\n        # feats are normalized embeds\n      
  if use_mpm and np.random.uniform() < self.use_mask_prob:\n            video_embeds_total = self._forward_visual_embeds(torch.cat([visual_inputs, context_visual_inputs], dim=0))\n            # split for unmasked and masked\n            video_embeds, context_video_embeds = video_embeds_total[:b], video_embeds_total[b:]\n        else:\n            video_embeds = self._forward_visual_embeds(visual_inputs)\n            context_video_embeds = video_embeds\n\n        # we compute normalized feats for unmasked visual inputs only, used for ITC\n        video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  \n        video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)\n        \n        # text embeddings and features\n        text_embeds, text_feat = self._forward_text_feats(batch)\n\n        # ========== (in-batch) ITC loss ==========\n        gathered_video_feats = hvd.allgather(video_feat)\n        gathered_text_feats = hvd.allgather(text_feat)\n\n        assert self.itc_token_type == 'cls', 'Support CLS tokens for ITC only, find {}.'.format(self.itc_token_type)\n        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp \n        sim_t2v = text_feat @ gathered_video_feats.t() / self.temp \n                             \n        # [IMPORTANT] be very careful when initializing the GT sim_v2t \n        # allgather return the concatenated features in the order of local_rank()\n        sim_targets = torch.zeros_like(sim_v2t)\n\n        local_rank = hvd.local_rank()\n        b_start, b_end = b * local_rank, b * (local_rank + 1)\n        sim_targets[:, b_start: b_end] = torch.eye(b)\n\n        loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1)*sim_targets,dim=1).mean()\n        loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1)*sim_targets,dim=1).mean() \n\n        vtc_loss = (loss_v2t+loss_t2v) / 2\n\n        # ========= VTM ==========\n        text_atts = batch['text_input_mask']\n\n        # non-masked text and non-masked image \n        vtm_loss, vtm_logits, vtm_labels, encoder_outputs_pos = self.compute_vtm(text_embeds=text_embeds, \n                                                                                 text_atts=text_atts, \n                                                                                 video_embeds=video_embeds, \n                                                                                 video_atts=video_atts, \n                                                                                 sim_v2t=sim_v2t.clone(), # for hard mining\n                                                                                 sim_t2v=sim_t2v.clone(), # for hard mining\n                                                                                 return_encoder_out=True\n                                                                                )\n\n        # ========= MLM ========== \n        # masked text and non-masked image\n        if 'mlm_labels' in batch: \n            mlm_labels = batch['mlm_labels']\n            mlm_text_input_ids = batch['mlm_text_input_ids']\n\n            mlm_loss, mlm_logits, mlm_labels = self.compute_mlm(input_ids=mlm_text_input_ids,\n                                                                text_input_mask=text_atts,\n                                                                video_embeds=video_embeds, \n                                                                video_atts=video_atts,\n                                                                
mlm_labels=mlm_labels\n                                                                )\n        else:\n            mlm_logits = mlm_loss = mlm_labels = None\n\n        # ========= MPM ========== \n        if use_mpm: \n            mpm_labels, ignore_masks = self.get_pseudo_labels(batch)\n\n            mpm_loss, mpm_logits = self.compute_mpm_with_encoder_out(encoder_outputs=encoder_outputs_pos, \n                                                                     text_atts=text_atts, \n                                                                     soft_labels=mpm_labels, \n                                                                     ignore_masks=ignore_masks, \n                                                                     patch_masks=batch['mpm_mask']\n                                                                    )\n\n        else:\n            mpm_loss = mpm_logits = mpm_labels =  None\n\n        return dict(\n            itc_loss=vtc_loss,\n            mlm_scores=mlm_logits,  # (B, Lt, vocab_size),  only text part\n            mlm_loss=mlm_loss,  # (BxLt)\n            mlm_labels=mlm_labels,  # (1, Lt), with -100 indicates ignored positions\n            itm_scores=vtm_logits,  # (B, 2)\n            itm_loss=vtm_loss,  # (1, )\n            itm_labels=vtm_labels,  # (B, )\n            mpm_loss=mpm_loss,\n            mpm_logits=mpm_logits,\n            mpm_labels=mpm_labels\n        )\n\n\n    def _forward_visual_embeds(self, visual_inputs):\n        b, t, c, h, w = visual_inputs.shape\n        # timeSformer asks for (b, c, t, h, w) as input.\n        # image features\n        visual_inputs = visual_inputs.transpose(1, 2)\n\n        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)\n\n        return video_embeds\n\n    def _forward_text_feats(self, batch):\n        # text features\n        text_output = self.text_encoder.bert(batch['text_input_ids'], \n                                             attention_mask=batch['text_input_mask'],                      \n                                             return_dict = True, \n                                             mode = 'text'\n                                            )\n\n        text_embeds = text_output.last_hidden_state # b, Lt, fsz=768\n        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 \n\n        return text_embeds, text_feat\n\n    def compute_mpm_with_encoder_out(self, encoder_outputs, text_atts, soft_labels, ignore_masks, patch_masks):\n        txt_len = text_atts.shape[1]\n        # adding one to ignore visual cls tokens\n        visual_output = encoder_outputs.last_hidden_state[:, txt_len+1:]\n\n        bsz, h, w = patch_masks.shape\n        patch_masks_flatten_inverted = (1 - patch_masks.view(bsz, -1)).unsqueeze(-1)\n\n        # mean embeds of masked visual regions\n        num_masked_patches = torch.sum(patch_masks_flatten_inverted.squeeze(-1), dim=-1, keepdim=True)\n\n        masked_visual_embeds = patch_masks_flatten_inverted * visual_output\n        masked_visual_embeds = torch.sum(masked_visual_embeds, dim=1)\n        masked_visual_embeds /= num_masked_patches\n\n        # loss\n        mpm_logits = self.mpm_head(masked_visual_embeds)\n\n        cross_entropy = -torch.sum(F.log_softmax(mpm_logits, dim=1) * soft_labels, dim=1)\n        cross_entropy[ignore_masks] = 0.\n\n        mpm_loss = torch.sum(cross_entropy) / (bsz - torch.sum(ignore_masks))\n\n        return mpm_loss, mpm_logits \n\n    def 
compute_mpm(self, text_embeds, text_atts, image_embeds, image_atts, soft_labels, ignore_masks, patch_masks):\n        # forward cross-encoder\n        attention_mask = torch.cat([text_atts, image_atts], dim=1)\n        embedding_output = torch.cat([text_embeds, image_embeds], dim=1)\n\n        encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,\n                                                 attention_mask=attention_mask,\n                                                 return_dict=True,\n                                                 mode='fusion'\n                                                )\n\n        txt_len = text_atts.shape[1]\n        # adding one to ignore visual cls tokens\n        visual_output = encoder_outputs.last_hidden_state[:, txt_len+1:]\n\n        bsz, h, w = patch_masks.shape\n        patch_masks_flatten_inverted = (1 - patch_masks.view(bsz, -1)).unsqueeze(-1)\n\n        # mean embeds of masked visual regions\n        num_masked_patches = torch.sum(patch_masks_flatten_inverted.squeeze(-1), dim=-1, keepdim=True)\n\n        masked_visual_embeds = patch_masks_flatten_inverted * visual_output\n        masked_visual_embeds = torch.sum(masked_visual_embeds, dim=1)\n        masked_visual_embeds /= num_masked_patches\n\n        # loss\n        mpm_logits = self.mpm_head(masked_visual_embeds)\n\n        cross_entropy = -torch.sum(F.log_softmax(mpm_logits, dim=1) * soft_labels, dim=1)\n        cross_entropy[ignore_masks] = 0.\n\n        mpm_loss = torch.sum(cross_entropy) / (bsz - torch.sum(ignore_masks))\n\n        return mpm_loss, mpm_logits \n\n    def compute_vtm(self, text_embeds, text_atts, video_embeds, video_atts, sim_v2t, sim_t2v, return_encoder_out=False):\n        device = text_embeds.device\n\n        # ====== positive pairs =======\n        attention_mask = torch.cat([text_atts, video_atts], dim=1)\n        embedding_output_pos = torch.cat([text_embeds, video_embeds], dim=1)\n\n        encoder_outputs_pos = self.text_encoder.bert(encoder_embeds=embedding_output_pos,\n                                                     attention_mask=attention_mask,\n                                                     return_dict=True,\n                                                     mode='fusion'\n                                                    )\n\n        # ====== negative pairs =======\n        bs = text_embeds.shape[0] \n\n        local_rank = hvd.local_rank()\n        b_start, b_end = bs * local_rank, bs * (local_rank + 1)\n\n        with torch.no_grad():\n            weights_i2t = sim_v2t[:,b_start:b_end]\n            weights_t2i = sim_t2v[:,b_start:b_end]\n   \n            # never select self as negative\n            weights_i2t.fill_diagonal_(-np.Inf)\n            weights_t2i.fill_diagonal_(-np.Inf)\n\n            weights_i2t = F.softmax(weights_i2t, dim=1)\n            weights_t2i = F.softmax(weights_t2i, dim=1)\n\n        # select a negative image for each text\n        # FIXME to optimize using indexing operations\n        video_embeds_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_t2i[b], 1).item()\n            video_embeds_neg.append(video_embeds[neg_idx])\n        video_embeds_neg = torch.stack(video_embeds_neg,dim=0)   \n\n        # select a negative text for each image\n        text_embeds_neg = []\n        text_atts_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_i2t[b], 1).item()\n            text_embeds_neg.append(text_embeds[neg_idx])\n          
  text_atts_neg.append(text_atts[neg_idx])\n\n        text_embeds_neg = torch.stack(text_embeds_neg,dim=0)   \n        text_atts_neg = torch.stack(text_atts_neg,dim=0)      \n\n        text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)     \n        text_atts_all = torch.cat([text_atts, text_atts_neg],dim=0)     \n\n        video_embeds_all = torch.cat([video_embeds_neg,video_embeds],dim=0)\n        video_atts_all = torch.cat([video_atts,video_atts],dim=0)\n\n        attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)\n        embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)\n\n        # forward negative pairs via cross encoder\n        encoder_outputs_neg = self.text_encoder.bert(encoder_embeds=embedding_output_all,\n                                                     attention_mask=attention_mask_all,\n                                                     return_dict=True,\n                                                     mode='fusion'\n                                                    )\n\n        vl_embeddings = torch.cat([encoder_outputs_pos.last_hidden_state[:,0,:], \n                                   encoder_outputs_neg.last_hidden_state[:,0,:]],dim=0)\n        vtm_logits = self.itm_head(vl_embeddings)            \n\n        vtm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], dim=0).to(device)\n        vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)     \n\n        if return_encoder_out:\n            return vtm_loss, vtm_logits, vtm_labels, encoder_outputs_pos \n        else:\n            return vtm_loss, vtm_logits, vtm_labels, None\n        \n    def compute_mlm(self, input_ids, text_input_mask, video_embeds, video_atts, mlm_labels):\n        # forward text features with masked_input_ids\n        text_output = self.text_encoder.bert(input_ids,\n                                             attention_mask=text_input_mask,\n                                             return_dict=True,\n                                             mode='text'\n                                            )\n        text_embeds = text_output.last_hidden_state\n\n        # forward cross-encoder\n        attention_mask = torch.cat([text_input_mask, video_atts], dim=1)\n        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)\n\n        encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,\n                                                 attention_mask=attention_mask,\n                                                 return_dict=True,\n                                                 mode='fusion'\n                                                )\n\n        txt_len = text_input_mask.shape[1]\n        txt_output = encoder_outputs.last_hidden_state[:, :txt_len]\n\n        mlm_logits = self.text_encoder.cls(txt_output)\n\n        loss_fct = CrossEntropyLoss()\n        mlm_loss = loss_fct(mlm_logits.view(-1, self.bert_config.vocab_size), mlm_labels.view(-1))\n\n        return mlm_loss, mlm_logits, mlm_labels\n        \n    def load_separate_ckpt(self, visual_weights_path=None, bert_weights_path=None, prompter_weights_path=None):\n        if visual_weights_path:\n            self.visual_encoder.load_state_dict(visual_weights_path)\n\n        # [NOTE] BERT is initialized from huggingface pre-trained weights. 
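Only the visual encoder weights and, optionally, the prompter weights are loaded here.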
\n        # if bert_weights_path:\n        #     load_multimodal_encoder_state_dict_with_mismatch(self.cross_encoder, bert_weights_path)\n        #     load_mlm_head_state_dict_with_mismatch(self.mlm_head, bert_weights_path)\n\n        # TODO make path configurable\n        if prompter_weights_path is not None:\n            self.prompter.load_pretrained_weights_without_prompts(prompter_weights_path)\n\n\nclass Prompter(AlproBaseModel):\n    def __init__(self, config, video_enc_cfg, input_format='RGB'):\n        super(Prompter, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)\n\n        # self.entity_num = 1000\n        self.entity_num = config.num_entities\n\n        self.register_buffer(\"video_prompt_feat\", torch.rand(self.entity_num, 256)) \n        self.register_buffer(\"image_prompt_feat\", torch.rand(self.entity_num, 256)) \n\n        self.prompt_initialized = False\n        # if the prob for the most likely entity is < 0.2, we just ignore it\n        self.ignore_threshold = 0.2\n\n\n    def load_pretrained_weights_without_prompts(self, ckpt_path):\n        LOGGER.info(\"Loading weights for teacher model.\")\n        loaded_state_dict = torch.load(ckpt_path, map_location='cpu')\n\n        loaded_keys = loaded_state_dict.keys()\n        model_keys = self.state_dict().keys()\n\n        load_not_in_model = [k for k in loaded_keys if k not in model_keys]\n        model_not_in_load = [k for k in model_keys if k not in loaded_keys]\n\n        if hvd.rank() == 0:\n            LOGGER.info(\"Keys in loaded but not in model:\")\n            LOGGER.info(f\"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}\")\n            LOGGER.info(\"Keys in model but not in loaded:\")\n            LOGGER.info(f\"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}\")\n\n        # FIXME a quick hack to avoid loading prompts\n        new_loaded_state_dict = dict()\n        for k in loaded_state_dict:\n            if not 'prompt_feat' in k:\n                new_loaded_state_dict[k] = loaded_state_dict[k]\n\n        loaded_state_dict = new_loaded_state_dict\n\n        self.load_state_dict(loaded_state_dict, strict=False)\n\n    def build_text_prompts(self, prompts):\n        \"\"\"\n        This function will be called if no e2e.weights is provided.\n        In that case, the video and image prompt features are built here by encoding the given\n        text prompts with the text encoder and averaging over the prompt templates.\n        \"\"\"\n        assert not self.prompt_initialized, \"Repeatedly building prompts?\"\n\n        if self.training:\n            self.eval()\n\n        video_prompt_feat_all = []\n        image_prompt_feat_all = []\n\n        with torch.no_grad():\n            # this is configurable depending on the GPU memory limit\n            step_size = 10000\n\n            # ====== initializing video prompting ======\n            b_video, _ = prompts['batch_enc_video_prompts'].input_ids.shape\n\n            start = 0\n            end = start + step_size\n\n            while start < b_video:\n                video_prompt_output = self.text_encoder.bert(prompts['batch_enc_video_prompts'].input_ids[start:end].cuda(), \n                                                            attention_mask=prompts['batch_enc_video_prompts'].attention_mask[start:end].cuda(),                      \n                                                            return_dict=True, \n                                                            mode='text'\n                                                            )\n\n                video_prompt_embeds = video_prompt_output.last_hidden_state # b, Lt, fsz=768\n           
     video_prompt_feat = F.normalize(self.text_proj(video_prompt_embeds[:,0,:]),dim=-1)                 \n\n                # collecting\n                video_prompt_feat_all.append(video_prompt_feat)\n            \n                start += step_size\n                end += step_size\n\n            # average ensembling\n            video_prompt_feat = torch.cat(video_prompt_feat_all, dim=0)\n            video_num_templates = int(video_prompt_feat.shape[0] / self.entity_num)\n\n            video_prompt_feat = torch.stack(video_prompt_feat.chunk(video_num_templates), dim=1)\n            video_prompt_feat = torch.mean(video_prompt_feat, dim=1)\n            self.video_prompt_feat = video_prompt_feat\n\n            # ====== initializing image prompting ======\n            b_image, _ = prompts['batch_enc_image_prompts'].input_ids.shape\n\n            start = 0\n            end = start + step_size\n\n            while start < b_image:\n                # image prompts\n                image_prompt_output = self.text_encoder.bert(prompts['batch_enc_image_prompts'].input_ids[start:end].cuda(), \n                                                            attention_mask=prompts['batch_enc_image_prompts'].attention_mask[start:end].cuda(),                      \n                                                            return_dict = True, \n                                                            mode = 'text'\n                                                            )\n\n                image_prompt_embeds = image_prompt_output.last_hidden_state # b, Lt, fsz=768\n                image_prompt_feat = F.normalize(self.text_proj(image_prompt_embeds[:,0,:]),dim=-1)                 \n\n                # collecting\n                image_prompt_feat_all.append(image_prompt_feat)\n\n                start += step_size\n                end += step_size\n\n            image_prompt_feat = torch.cat(image_prompt_feat_all, dim=0)\n            image_num_templates = int(image_prompt_feat.shape[0] / self.entity_num)\n\n            image_prompt_feat = torch.stack(image_prompt_feat.chunk(image_num_templates), dim=1)\n            image_prompt_feat = torch.mean(image_prompt_feat, dim=1)\n            self.image_prompt_feat = image_prompt_feat\n\n        self.prompt_initialized = True\n\n    def _forward_visual_embeds(self, visual_inputs):\n        b, t, c, h, w = visual_inputs.shape\n        # timeSformer asks for (b, c, t, h, w) as input.\n        # image features\n        visual_inputs = visual_inputs.transpose(1, 2)\n\n        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)\n\n        assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)\n        if self.itc_token_type == 'cls':\n            video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  \n        else:\n            raise NotImplementedError(\"itc_type_type must be one of ['mean', 'cls', 'mil'], found {}\".format(self.itc_token_type))\n        \n        return video_embeds, video_feat\n\n    def _compute_soft_labels(self, sim_vp_masked):\n        soft_labels = nn.Softmax(dim=1)(sim_vp_masked)\n        ignore_masks = torch.max(sim_vp_masked, dim=1)[1] < self.ignore_threshold\n\n        return soft_labels, ignore_masks\n\n    def get_pseudo_labels(self, batch):\n        if self.training:\n            self.eval()\n\n        with torch.no_grad():\n            masked_visual_inputs = batch['crop_visual_inputs']\n\n            _, masked_image_feat = 
self._forward_visual_embeds(masked_visual_inputs)\n\n            if batch['type'] == 'video':\n                prompt_feat = self.video_prompt_feat\n            else:\n                prompt_feat = self.image_prompt_feat\n\n            # visual feat to video prompts\n            # masked visual feat to video prompts\n            sim_masked = masked_image_feat @ prompt_feat.t() / self.temp \n\n            pseudo_labels, ignore_masks = self._compute_soft_labels(sim_masked)\n\n        return pseudo_labels, ignore_masks\n\n    def forward(self, batch):\n        visual_inputs = batch['visual_inputs']\n\n        device = visual_inputs.device\n        b, t, c, h, w = visual_inputs.shape\n\n        # forward image and text features\n        # feats are normalized embeds\n        video_embeds, video_feat, text_embeds, text_feat = self.forward_feats(batch)\n        image_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)\n\n        # ========== (in-batch) ITC loss ==========\n        gathered_image_feats = hvd.allgather(video_feat)\n        gathered_text_feats = hvd.allgather(text_feat)\n\n        assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)\n\n        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp \n        sim_t2v = text_feat @ gathered_image_feats.t() / self.temp \n                             \n        # [IMPORTANT] be very careful when initializing the GT sim_i2t \n        # allgather return the concatenated features in the order of local_rank()\n        sim_targets = torch.zeros_like(sim_v2t)\n\n        local_rank = hvd.local_rank()\n        b_start, b_end = b * local_rank, b * (local_rank + 1)\n        sim_targets[:, b_start: b_end] = torch.eye(b)\n\n        sim_v2t_scores = F.log_softmax(sim_v2t, dim=1)\n        sim_t2v_scores = F.log_softmax(sim_t2v, dim=1)\n\n        loss_v2t = -torch.sum(sim_v2t_scores * sim_targets,dim=1).mean()\n        loss_t2v = -torch.sum(sim_t2v_scores * sim_targets,dim=1).mean() \n\n        vtc_loss = (loss_v2t+loss_t2v) / 2\n\n        return dict(\n            itc_loss=vtc_loss,\n            itc_labels=torch.max(sim_targets, dim=1)[1],\n            i2t_scores=sim_v2t_scores,\n            t2i_scores=sim_t2v_scores\n        )\n\n\n    def forward_feats(self, batch):\n        with torch.no_grad():\n            self.temp.clamp_(0.001,0.5)\n\n        visual_inputs = batch['visual_inputs']\n\n        b, t, c, h, w = visual_inputs.shape\n        # timeSformer asks for (b, c, t, h, w) as input.\n        # image features\n        visual_inputs = visual_inputs.transpose(1, 2)\n\n        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)\n\n        assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)\n        if self.itc_token_type == 'cls':\n            video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  \n        else:\n            raise NotImplementedError(\"itc_type_type must be one of ['mean', 'cls', 'mil'], found {}\".format(self.itc_token_type))\n\n        # text features\n        text_output = self.text_encoder.bert(batch['text_input_ids'], \n                                             attention_mask=batch['text_input_mask'],                      \n                                             return_dict = True, \n                                             mode = 'text'\n                                            )\n\n        text_embeds = 
text_output.last_hidden_state # b, Lt, fsz=768\n\n        if self.itc_token_type == 'cls':\n            text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 \n        else:\n            raise NotImplementedError(\"itc_token_type must be one of ['mean', 'cls', 'mil'], found {}\".format(self.itc_token_type))\n\n        return video_embeds, video_feat, text_embeds, text_feat\n\n\nclass AlproForSequenceClassification(AlproBaseModel):\n    def __init__(self, config, video_enc_cfg, input_format='RGB'):\n        super(AlproForSequenceClassification, self).__init__(config, video_enc_cfg=video_enc_cfg)\n\n        self.text_encoder = BertModel.from_pretrained('bert-base-uncased', config=self.bert_config, add_pooling_layer=False)      \n\n        self.classifier = nn.Sequential(\n            nn.Linear(config.hidden_size,\n                      config.hidden_size * 2),\n            nn.ReLU(True),\n            nn.Linear(config.hidden_size * 2, config.num_labels)\n        )\n\n    # def forward(self, image, text, targets, alpha=0, train=True):\n    def forward(self, batch):\n        visual_inputs = batch['visual_inputs']\n        targets = batch['labels']\n\n        device = visual_inputs.device\n\n        # forward text\n        text_input_mask = batch['text_input_mask']\n        text_output = self.text_encoder(batch['text_input_ids'],\n                                        attention_mask=text_input_mask,\n                                        return_dict=True,\n                                        mode='text'\n                                        )\n        text_embeds = text_output.last_hidden_state\n\n        # forward visual\n        b, t, c, h, w = visual_inputs.shape\n        # timeSformer asks for (b, c, t, h, w) as input.\n        visual_inputs = visual_inputs.transpose(1, 2)\n\n        image_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)\n        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(device)\n\n        # forward cross-encoder\n        attention_mask = torch.cat([text_input_mask, image_atts], dim=1)\n        embedding_output = torch.cat([text_embeds, image_embeds], dim=1)\n\n        output = self.text_encoder(encoder_embeds=embedding_output,\n                                attention_mask=attention_mask,\n                                return_dict=True,\n                                mode='fusion'\n                                )\n\n        prediction = self.classifier(output.last_hidden_state[:,0,:])                \n        if targets is not None:\n            loss = F.cross_entropy(prediction, targets)                \n        else: # evaluation mode\n            loss = 0\n\n        return dict(loss=loss,\n                    logits=prediction\n                    )\n            \n\n    def forward_inference(self, batch):\n        visual_inputs = batch['visual_inputs']\n        device = visual_inputs.device\n\n        # forward text\n        text_input_mask = batch['text_input_mask']\n        text_output = self.text_encoder.bert(batch['text_input_ids'],\n                                             attention_mask=text_input_mask,\n                                             return_dict=True,\n                                             mode='text'\n                                            )\n        text_embeds = text_output.last_hidden_state\n\n        # forward visual\n        b, t, c, h, w = visual_inputs.shape\n        # timeSformer asks for (b, c, t, h, w) as input.\n  
      visual_inputs = visual_inputs.transpose(1, 2)\n\n        image_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)\n        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(device)\n\n        # forward cross-encoder\n        attention_mask = torch.cat([text_input_mask, image_atts], dim=1)\n        embedding_output = torch.cat([text_embeds, image_embeds], dim=1)\n\n        output = self.text_encoder.bert(encoder_embeds=embedding_output,\n                                        attention_mask=attention_mask,\n                                        return_dict=True,\n                                        mode='fusion'\n                                    )\n\n        prediction = self.classifier(output.last_hidden_state[:,0,:])                \n\n        return prediction\n\n\nclass AlproForVideoTextRetrieval(AlproBaseModel):\n    \"\"\"\n    \"\"\"\n    def __init__(self, config, video_enc_cfg, input_format='RGB'):\n        super(AlproForVideoTextRetrieval, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)\n\n    def forward(self, batch):\n        with torch.no_grad():\n            self.temp.clamp_(0.001,0.5)\n\n        visual_inputs = batch['visual_inputs']\n        text_input_mask = batch['text_input_mask']\n        text_input_ids = batch['text_input_ids']\n\n        device = visual_inputs.device\n\n        b, t, c, h, w = visual_inputs.shape\n        # timeSformer asks for (b, c, t, h, w) as input.\n        # visual embeddings\n        visual_inputs = visual_inputs.transpose(1, 2)\n\n        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)\n        # image_embeds = image_embeds.repeat(text_input_mask.shape[0], 1, 1)\n        video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  \n\n        video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)\n\n        # text embeddings\n        text_output = self.text_encoder.bert(text_input_ids,\n                                             attention_mask=text_input_mask,\n                                             return_dict=True,\n                                             mode='text'\n                                            )\n        text_embeds = text_output.last_hidden_state\n        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 \n\n        # ========== (in-batch) ITC loss ==========\n        gathered_video_feats = hvd.allgather(video_feat)\n        gathered_text_feats = hvd.allgather(text_feat)\n\n        sim_v2t = video_feat @ gathered_text_feats.t() / self.temp \n        sim_t2v = text_feat @ gathered_video_feats.t() / self.temp \n\n        sim_targets = torch.zeros_like(sim_v2t)\n\n        local_rank = hvd.local_rank()\n        b_start, b_end = b * local_rank, b * (local_rank + 1)\n        sim_targets[:, b_start: b_end] = torch.eye(b)\n\n        loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1)*sim_targets,dim=1).mean()\n        loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1)*sim_targets,dim=1).mean() \n\n        vtc_loss = (loss_v2t+loss_t2v) / 2\n\n        # ========= ITM ==========\n        text_atts = batch['text_input_mask']\n\n        # non-masked text and non-masked image \n        vtm_loss, vtm_logits, vtm_labels = self.compute_vtm(text_embeds=text_embeds, \n                                                            text_atts=text_atts, \n                                                            
image_embeds=video_embeds, \n                                                            image_atts=video_atts, \n                                                            sim_i2t=sim_v2t.clone(), # for hard mining\n                                                            sim_t2i=sim_t2v.clone()  # for hard mining\n                                                           )\n\n        return dict(\n            itm_scores=vtm_logits,\n            itm_loss=vtm_loss,\n            itm_labels=vtm_labels,\n            itc_loss=vtc_loss\n        )\n    \n    def compute_vtm(self, text_embeds, text_atts, image_embeds, image_atts, sim_i2t, sim_t2i):\n        device = text_embeds.device\n\n        # ====== positive pairs =======\n        attention_mask = torch.cat([text_atts, image_atts], dim=1)\n        embedding_output_pos = torch.cat([text_embeds, image_embeds], dim=1)\n\n        encoder_outputs_pos = self.text_encoder.bert(encoder_embeds=embedding_output_pos,\n                                                     attention_mask=attention_mask,\n                                                     return_dict=True,\n                                                     mode='fusion'\n                                                    )\n\n        # ====== negative pairs =======\n        bs = text_embeds.shape[0] \n\n        local_rank = hvd.local_rank()\n        b_start, b_end = bs * local_rank, bs * (local_rank + 1)\n\n        with torch.no_grad():\n            weights_v2t = sim_i2t[:,b_start:b_end]\n            weights_t2v = sim_t2i[:,b_start:b_end]\n   \n            # never select self as negative\n            weights_v2t.fill_diagonal_(-np.Inf)\n            weights_t2v.fill_diagonal_(-np.Inf)\n\n            weights_v2t = F.softmax(weights_v2t, dim=1)\n            weights_t2v = F.softmax(weights_t2v, dim=1)\n\n        # select a negative image for each text\n        # FIXME to optimize using indexing operations\n        image_embeds_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_t2v[b], 1).item()\n            image_embeds_neg.append(image_embeds[neg_idx])\n        image_embeds_neg = torch.stack(image_embeds_neg,dim=0)   \n\n        # select a negative text for each image\n        text_embeds_neg = []\n        text_atts_neg = []\n        for b in range(bs):\n            neg_idx = torch.multinomial(weights_v2t[b], 1).item()\n            text_embeds_neg.append(text_embeds[neg_idx])\n            text_atts_neg.append(text_atts[neg_idx])\n\n        text_embeds_neg = torch.stack(text_embeds_neg,dim=0)   \n        text_atts_neg = torch.stack(text_atts_neg,dim=0)      \n\n        text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)     \n        text_atts_all = torch.cat([text_atts, text_atts_neg],dim=0)     \n\n        video_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)\n        video_atts_all = torch.cat([image_atts,image_atts],dim=0)\n\n        attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)\n        embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)\n\n        # forward negative pairs via cross encoder\n        encoder_outputs_neg = self.text_encoder.bert(encoder_embeds=embedding_output_all,\n                                                     attention_mask=attention_mask_all,\n                                                     return_dict=True,\n                                                     mode='fusion'\n                                                    )\n\n     
   vl_embeddings = torch.cat([encoder_outputs_pos.last_hidden_state[:,0,:], \n                                   encoder_outputs_neg.last_hidden_state[:,0,:]],dim=0)\n        vtm_logits = self.itm_head(vl_embeddings)            \n\n        vtm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], dim=0).to(device)\n        vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)     \n\n        return vtm_loss, vtm_logits, vtm_labels \n\n    def forward_inference(self, batch):\n        visual_inputs = batch['visual_inputs']\n        text_input_mask = batch['text_input_mask']\n        text_input_ids = batch['text_input_ids']\n\n        device = visual_inputs.device\n\n        b, t, c, h, w = visual_inputs.shape\n        # timeSformer asks for (b, c, t, h, w) as input.\n        visual_inputs = visual_inputs.transpose(1, 2)\n\n        video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)\n        video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)  \n\n        video_embeds = video_embeds.repeat(text_input_mask.shape[0], 1, 1)\n        # image_feat = image_feat.repeat(text_input_mask.shape[0], 1)\n\n        video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)\n        text_output = self.text_encoder.bert(text_input_ids,\n                                             attention_mask=text_input_mask,\n                                             return_dict=True,\n                                             mode='text'\n                                            )\n        text_embeds = text_output.last_hidden_state\n        text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)                 \n\n        vtc_sim_scores = video_feat @ text_feat.t() / self.temp\n\n        attention_mask = torch.cat([text_input_mask, video_atts], dim=1)\n        embedding_output = torch.cat([text_embeds, video_embeds], dim=1)\n\n        encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,\n                                                 attention_mask=attention_mask,\n                                                 return_dict=True,\n                                                 mode='fusion'\n                                                )\n\n        vl_embeddings = encoder_outputs.last_hidden_state[:,0,:]\n        logits = self.itm_head(vl_embeddings)\n\n        return dict(logits=logits, itc_scores=vtc_sim_scores)\n\n"
  },
  {
    "path": "src/modeling/timesformer/__init__.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\n# from .build import MODEL_REGISTRY, build_model  # noqa\n# from .custom_video_model_builder import *  # noqa\n# from .video_model_builder import ResNet, SlowFast # noqa\n"
  },
  {
    "path": "src/modeling/timesformer/conv2d_same.py",
    "content": "# Copyright 2020 Ross Wightman\n# Conv2d w/ Same Padding\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Tuple, Optional\n\nimport math\nfrom typing import List, Tuple\n#from .padding import pad_same, get_padding_value\n\n# Dynamically pad input x with 'SAME' padding for conv with specified args\ndef pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):\n    ih, iw = x.size()[-2:]\n    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])\n    if pad_h > 0 or pad_w > 0:\n        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)\n    return x\n\n# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution\ndef get_same_padding(x: int, k: int, s: int, d: int):\n    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)\n\ndef get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:\n    dynamic = False\n    if isinstance(padding, str):\n        # for any string padding, the padding will be calculated for you, one of three ways\n        padding = padding.lower()\n        if padding == 'same':\n            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact\n            if is_static_pad(kernel_size, **kwargs):\n                # static case, no extra overhead\n                padding = get_padding(kernel_size, **kwargs)\n            else:\n                # dynamic 'SAME' padding, has runtime/GPU memory overhead\n                padding = 0\n                dynamic = True\n        elif padding == 'valid':\n            # 'VALID' padding, same as padding=0\n            padding = 0\n        else:\n            # Default to PyTorch style 'same'-ish symmetric padding\n            padding = get_padding(kernel_size, **kwargs)\n    return padding, dynamic\n\ndef conv2d_same(\n        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),\n        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):\n    x = pad_same(x, weight.shape[-2:], stride, dilation)\n    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)\n\n\nclass Conv2dSame(nn.Conv2d):\n    \"\"\" Tensorflow like 'SAME' convolution wrapper for 2D convolutions\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, dilation=1, groups=1, bias=True):\n        super(Conv2dSame, self).__init__(\n            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)\n\n    def forward(self, x):\n        return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\n\n\ndef create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):\n    padding = kwargs.pop('padding', '')\n    kwargs.setdefault('bias', False)\n    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)\n    if is_dynamic:\n        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)\n    else:\n        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)\n"
  },
  {
    "path": "src/modeling/timesformer/features.py",
    "content": "# Copyright 2020 Ross Wightman\n\nfrom collections import OrderedDict, defaultdict\nfrom copy import deepcopy\nfrom functools import partial\nfrom typing import Dict, List, Tuple\n\nimport torch\nimport torch.nn as nn\n\n\nclass FeatureInfo:\n\n    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):\n        prev_reduction = 1\n        for fi in feature_info:\n            # sanity check the mandatory fields, there may be additional fields depending on the model\n            assert 'num_chs' in fi and fi['num_chs'] > 0\n            assert 'reduction' in fi and fi['reduction'] >= prev_reduction\n            prev_reduction = fi['reduction']\n            assert 'module' in fi\n        self.out_indices = out_indices\n        self.info = feature_info\n\n    def from_other(self, out_indices: Tuple[int]):\n        return FeatureInfo(deepcopy(self.info), out_indices)\n\n    def get(self, key, idx=None):\n        \"\"\" Get value by key at specified index (indices)\n        if idx == None, returns value for key at each output index\n        if idx is an integer, return value for that feature module index (ignoring output indices)\n        if idx is a list/tupple, return value for each module index (ignoring output indices)\n        \"\"\"\n        if idx is None:\n            return [self.info[i][key] for i in self.out_indices]\n        if isinstance(idx, (tuple, list)):\n            return [self.info[i][key] for i in idx]\n        else:\n            return self.info[idx][key]\n\n    def get_dicts(self, keys=None, idx=None):\n        \"\"\" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)\n        \"\"\"\n        if idx is None:\n            if keys is None:\n                return [self.info[i] for i in self.out_indices]\n            else:\n                return [{k: self.info[i][k] for k in keys} for i in self.out_indices]\n        if isinstance(idx, (tuple, list)):\n            return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]\n        else:\n            return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}\n\n    def channels(self, idx=None):\n        \"\"\" feature channels accessor\n        \"\"\"\n        return self.get('num_chs', idx)\n\n    def reduction(self, idx=None):\n        \"\"\" feature reduction (output stride) accessor\n        \"\"\"\n        return self.get('reduction', idx)\n\n    def module_name(self, idx=None):\n        \"\"\" feature module name accessor\n        \"\"\"\n        return self.get('module', idx)\n\n    def __getitem__(self, item):\n        return self.info[item]\n\n    def __len__(self):\n        return len(self.info)\n\n\nclass FeatureHooks:\n    \"\"\" Feature Hook Helper\n    This module helps with the setup and extraction of hooks for extracting features from\n    internal nodes in a model by node name. 
 This works quite well in eager Python but needs\n    redesign for torchscript.\n    \"\"\"\n\n    def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):\n        # setup feature hooks\n        modules = {k: v for k, v in named_modules}\n        for i, h in enumerate(hooks):\n            hook_name = h['module']\n            m = modules[hook_name]\n            hook_id = out_map[i] if out_map else hook_name\n            hook_fn = partial(self._collect_output_hook, hook_id)\n            hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type\n            if hook_type == 'forward_pre':\n                m.register_forward_pre_hook(hook_fn)\n            elif hook_type == 'forward':\n                m.register_forward_hook(hook_fn)\n            else:\n                assert False, \"Unsupported hook type\"\n        self._feature_outputs = defaultdict(OrderedDict)\n\n    def _collect_output_hook(self, hook_id, *args):\n        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre\n        if isinstance(x, tuple):\n            x = x[0]  # unwrap input tuple\n        self._feature_outputs[x.device][hook_id] = x\n\n    def get_output(self, device) -> Dict[str, torch.Tensor]:\n        output = self._feature_outputs[device]\n        self._feature_outputs[device] = OrderedDict()  # clear after reading\n        return output\n\n\ndef _module_list(module, flatten_sequential=False):\n    # a yield/iter would be better for this but wouldn't be compatible with torchscript\n    ml = []\n    for name, module in module.named_children():\n        if flatten_sequential and isinstance(module, nn.Sequential):\n            # first level of Sequential containers is flattened into containing model\n            for child_name, child_module in module.named_children():\n                combined = [name, child_name]\n                ml.append(('_'.join(combined), '.'.join(combined), child_module))\n        else:\n            ml.append((name, name, module))\n    return ml\n\n\ndef _get_feature_info(net, out_indices):\n    feature_info = getattr(net, 'feature_info')\n    if isinstance(feature_info, FeatureInfo):\n        return feature_info.from_other(out_indices)\n    elif isinstance(feature_info, (list, tuple)):\n        return FeatureInfo(net.feature_info, out_indices)\n    else:\n        assert False, \"Provided feature_info is not valid\"\n\n\ndef _get_return_layers(feature_info, out_map):\n    module_names = feature_info.module_name()\n    return_layers = {}\n    for i, name in enumerate(module_names):\n        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]\n    return return_layers\n\n\nclass FeatureDictNet(nn.ModuleDict):\n    \"\"\" Feature extractor with OrderedDict return\n    Wrap a model and extract features as specified by the out indices; the network is\n    partially re-built from contained modules.\n    There is a strong assumption that the modules have been registered into the model in the same\n    order as they are used.
 There should be no reuse of the same nn.Module more than once, including\n    trivial modules like `self.relu = nn.ReLU`.\n    Only submodules that are directly assigned to the model class (`model.feature1`) or at most\n    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.\n    All Sequential containers that are directly assigned to the original model will have their\n    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`.\n    Arguments:\n        model (nn.Module): model from which we will extract the features\n        out_indices (tuple[int]): model output indices to extract features for\n        out_map (sequence): list or tuple specifying desired return id for each out index,\n            otherwise str(index) is used\n        feature_concat (bool): whether to concatenate intermediate features that are lists or tuples\n            vs select element [0]\n        flatten_sequential (bool): whether to flatten sequential modules assigned to model\n    \"\"\"\n    def __init__(\n            self, model,\n            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):\n        super(FeatureDictNet, self).__init__()\n        self.feature_info = _get_feature_info(model, out_indices)\n        self.concat = feature_concat\n        self.return_layers = {}\n        return_layers = _get_return_layers(self.feature_info, out_map)\n        modules = _module_list(model, flatten_sequential=flatten_sequential)\n        remaining = set(return_layers.keys())\n        layers = OrderedDict()\n        for new_name, old_name, module in modules:\n            layers[new_name] = module\n            if old_name in remaining:\n                # return id has to be consistently str type for torchscript\n                self.return_layers[new_name] = str(return_layers[old_name])\n                remaining.remove(old_name)\n            if not remaining:\n                break\n        assert not remaining and len(self.return_layers) == len(return_layers), \\\n            f'Return layers ({remaining}) are not present in model'\n        self.update(layers)\n\n    def _collect(self, x) -> Dict[str, torch.Tensor]:\n        out = OrderedDict()\n        for name, module in self.items():\n            x = module(x)\n            if name in self.return_layers:\n                out_id = self.return_layers[name]\n                if isinstance(x, (tuple, list)):\n                    # If model tap is a tuple or list, concat or select first element\n                    # FIXME this may need to be more generic / flexible for some nets\n                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]\n                else:\n                    out[out_id] = x\n        return out\n\n    def forward(self, x) -> Dict[str, torch.Tensor]:\n        return self._collect(x)\n\n\nclass FeatureListNet(FeatureDictNet):\n    \"\"\" Feature extractor with list return\n    See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.\n    In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.\n    \"\"\"\n    def __init__(\n            self, model,\n            out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):\n        super(FeatureListNet, self).__init__(\n            model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
\n            flatten_sequential=flatten_sequential)\n\n    def forward(self, x) -> List[torch.Tensor]:\n        return list(self._collect(x).values())\n\n\nclass FeatureHookNet(nn.ModuleDict):\n    \"\"\" FeatureHookNet\n    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.\n    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying\n    network in any way.\n    If `no_rewrite` is False, the model will be re-written as in the\n    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.\n    FIXME this does not currently work with Torchscript, see FeatureHooks class\n    \"\"\"\n    def __init__(\n            self, model,\n            out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,\n            feature_concat=False, flatten_sequential=False, default_hook_type='forward'):\n        super(FeatureHookNet, self).__init__()\n        assert not torch.jit.is_scripting()\n        self.feature_info = _get_feature_info(model, out_indices)\n        self.out_as_dict = out_as_dict\n        layers = OrderedDict()\n        hooks = []\n        if no_rewrite:\n            assert not flatten_sequential\n            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?\n                model.reset_classifier(0)\n            layers['body'] = model\n            hooks.extend(self.feature_info.get_dicts())\n        else:\n            modules = _module_list(model, flatten_sequential=flatten_sequential)\n            remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type\n                         for f in self.feature_info.get_dicts()}\n            for new_name, old_name, module in modules:\n                layers[new_name] = module\n                for fn, fm in module.named_modules(prefix=old_name):\n                    if fn in remaining:\n                        hooks.append(dict(module=fn, hook_type=remaining[fn]))\n                        del remaining[fn]\n                if not remaining:\n                    break\n            assert not remaining, f'Return layers ({remaining}) are not present in model'\n        self.update(layers)\n        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)\n\n    def forward(self, x):\n        for name, module in self.items():\n            x = module(x)\n        out = self.hooks.get_output(x.device)\n        return out if self.out_as_dict else list(out.values())\n
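\n\nif __name__ == '__main__':\n    # Minimal usage sketch (added for illustration; not part of the original file).\n    # A toy backbone exposing the `feature_info` attribute these wrappers expect;\n    # the module names below are made up for the example.\n    class _Toy(nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.stem = nn.Conv2d(3, 8, 3, stride=2, padding=1)\n            self.stage1 = nn.Conv2d(8, 16, 3, stride=2, padding=1)\n            self.feature_info = [\n                dict(num_chs=8, reduction=2, module='stem'),\n                dict(num_chs=16, reduction=4, module='stage1'),\n            ]\n\n    feats = FeatureListNet(_Toy(), out_indices=(0, 1))\n    print([f.shape for f in feats(torch.randn(2, 3, 32, 32))])\n    # expected: [torch.Size([2, 8, 16, 16]), torch.Size([2, 16, 8, 8])]\n"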
  },
  {
    "path": "src/modeling/timesformer/helpers.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Copyright 2020 Ross Wightman\n# Modified model creation / weight loading / state_dict helpers\n\nimport logging\nimport os\nimport sys\nimport math\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom typing import Callable\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\nimport torch.nn.functional as F\n\nfrom src.modeling.timesformer.features import FeatureListNet, FeatureDictNet, FeatureHookNet\nfrom src.modeling.timesformer.conv2d_same import Conv2dSame\nfrom src.modeling.timesformer.linear import Linear\n\nfrom horovod import torch as hvd\n\n_logger = logging.getLogger()\n\ndef load_state_dict(checkpoint_path, use_ema=False):\n    if checkpoint_path and os.path.isfile(checkpoint_path):\n        checkpoint = torch.load(checkpoint_path, map_location='cpu')\n        state_dict_key = 'state_dict'\n        if isinstance(checkpoint, dict):\n            if use_ema and 'state_dict_ema' in checkpoint:\n                state_dict_key = 'state_dict_ema'\n        if state_dict_key and state_dict_key in checkpoint:\n            new_state_dict = OrderedDict()\n            for k, v in checkpoint[state_dict_key].items():\n                # strip `module.` prefix\n                name = k[7:] if k.startswith('module') else k\n                new_state_dict[name] = v\n            state_dict = new_state_dict\n        elif 'model_state' in checkpoint:\n            state_dict_key = 'model_state'\n            new_state_dict = OrderedDict()\n            for k, v in checkpoint[state_dict_key].items():\n                # strip `model.` prefix\n                name = k[6:] if k.startswith('model') else k\n                new_state_dict[name] = v\n            state_dict = new_state_dict\n        else:\n            state_dict = checkpoint\n        _logger.info(\"Loaded {} from checkpoint '{}'\".format(state_dict_key, checkpoint_path))\n        return state_dict\n    else:\n        _logger.error(\"No checkpoint found at '{}'\".format(checkpoint_path))\n        raise FileNotFoundError()\n\n\ndef load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):\n    state_dict = load_state_dict(checkpoint_path, use_ema)\n    model.load_state_dict(state_dict, strict=strict)\n\n\n# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):\n#     resume_epoch = None\n    # if os.path.isfile(checkpoint_path):\n    #     checkpoint = torch.load(checkpoint_path, map_location='cpu')\n    #     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:\n    #         if log_info:\n    #             _logger.info('Restoring model state from checkpoint...')\n    #         new_state_dict = OrderedDict()\n    #         for k, v in checkpoint['state_dict'].items():\n    #             name = k[7:] if k.startswith('module') else k\n    #             new_state_dict[name] = v\n    #         model.load_state_dict(new_state_dict)\n\n    #         if optimizer is not None and 'optimizer' in checkpoint:\n    #             if log_info:\n    #                 _logger.info('Restoring optimizer state from checkpoint...')\n    #             optimizer.load_state_dict(checkpoint['optimizer'])\n\n    #         if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:\n    #             if log_info:\n    #                 _logger.info('Restoring AMP loss scaler state from checkpoint...')\n    #             
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])\n\n    #         if 'epoch' in checkpoint:\n    #             resume_epoch = checkpoint['epoch']\n    #             if 'version' in checkpoint and checkpoint['version'] > 1:\n    #                 resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save\n\n    #         if log_info:\n    #             _logger.info(\"Loaded checkpoint '{}' (epoch {})\".format(checkpoint_path, checkpoint['epoch']))\n    #     else:\n    #         model.load_state_dict(checkpoint)\n    #         if log_info:\n    #             _logger.info(\"Loaded checkpoint '{}'\".format(checkpoint_path))\n    #     return resume_epoch\n    # else:\n    #     _logger.error(\"No checkpoint found at '{}'\".format(checkpoint_path))\n    #     raise FileNotFoundError()\n\n\ndef load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8, num_patches=196, attention_type='divided_space_time', pretrained_model=\"\", strict=True):\n    if cfg is None:\n        cfg = getattr(model, 'default_cfg')\n    if cfg is None or 'url' not in cfg or not cfg['url']:\n        _logger.warning(\"Pretrained model URL is invalid, using random initialization.\")\n        return\n\n    if len(pretrained_model) == 0:\n        if cfg is None:\n            _logger.info(f\"loading from default config {model.default_cfg}.\")\n        state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')\n    else:\n        try:\n            state_dict = load_state_dict(pretrained_model)['model']\n        except Exception:\n            state_dict = load_state_dict(pretrained_model)\n\n\n    if filter_fn is not None:\n        state_dict = filter_fn(state_dict)\n\n    if in_chans == 1:\n        conv1_name = cfg['first_conv']\n        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)\n        conv1_weight = state_dict[conv1_name + '.weight']\n        conv1_type = conv1_weight.dtype\n        conv1_weight = conv1_weight.float()\n        O, I, J, K = conv1_weight.shape\n        if I > 3:\n            assert conv1_weight.shape[1] % 3 == 0\n            # For models with space2depth stems\n            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)\n            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)\n        else:\n            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)\n        conv1_weight = conv1_weight.to(conv1_type)\n        state_dict[conv1_name + '.weight'] = conv1_weight\n    elif in_chans != 3:\n        conv1_name = cfg['first_conv']\n        conv1_weight = state_dict[conv1_name + '.weight']\n        conv1_type = conv1_weight.dtype\n        conv1_weight = conv1_weight.float()\n        O, I, J, K = conv1_weight.shape\n        if I != 3:\n            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)\n            del state_dict[conv1_name + '.weight']\n            strict = False\n        else:\n            _logger.info('Repeating first conv (%s) weights in channel dim.'
 % conv1_name)\n            repeat = int(math.ceil(in_chans / 3))\n            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]\n            conv1_weight *= (3 / float(in_chans))\n            conv1_weight = conv1_weight.to(conv1_type)\n            state_dict[conv1_name + '.weight'] = conv1_weight\n\n\n    classifier_name = cfg['classifier']\n    if num_classes == 1000 and cfg['num_classes'] == 1001:\n        # special case for imagenet trained models with extra background class in pretrained weights\n        classifier_weight = state_dict[classifier_name + '.weight']\n        state_dict[classifier_name + '.weight'] = classifier_weight[1:]\n        classifier_bias = state_dict[classifier_name + '.bias']\n        state_dict[classifier_name + '.bias'] = classifier_bias[1:]\n    elif num_classes != state_dict[classifier_name + '.weight'].size(0):\n        #print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True)\n        # completely discard fully connected for all other differences between pretrained and created model\n        del state_dict[classifier_name + '.weight']\n        del state_dict[classifier_name + '.bias']\n        strict = False\n\n\n    ## Resizing the positional embeddings in case they don't match\n    if num_patches + 1 != state_dict['pos_embed'].size(1):\n        _logger.info(f\"Resizing spatial position embedding from {state_dict['pos_embed'].size(1)} to {num_patches + 1}\")\n        pos_embed = state_dict['pos_embed']\n        cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1)\n        other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2)\n        new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')\n        new_pos_embed = new_pos_embed.transpose(1, 2)\n        new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)\n        state_dict['pos_embed'] = new_pos_embed\n\n    ## Resizing time embeddings in case they don't match\n    if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):\n        _logger.info(f\"Resizing temporal position embedding from {state_dict['time_embed'].size(1)} to {num_frames}\")\n        time_embed = state_dict['time_embed'].transpose(1, 2)\n        new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest')\n        state_dict['time_embed'] = new_time_embed.transpose(1, 2)\n\n    ## Initializing temporal attention\n    if attention_type == 'divided_space_time':\n        new_state_dict = state_dict.copy()\n        for key in state_dict:\n            if 'blocks' in key and 'attn' in key:\n                new_key = key.replace('attn','temporal_attn')\n                if not new_key in state_dict:\n                   new_state_dict[new_key] = state_dict[key]\n                else:\n                   new_state_dict[new_key] = state_dict[new_key]\n            if 'blocks' in key and 'norm1' in key:\n                new_key = key.replace('norm1','temporal_norm1')\n                if not new_key in state_dict:\n                   new_state_dict[new_key] = state_dict[key]\n                else:\n                   new_state_dict[new_key] = state_dict[new_key]\n        state_dict = new_state_dict\n\n    ## Loading the weights\n    model.load_state_dict(state_dict, strict=False)\n\n\ndef load_pretrained_CLIP_ViT(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):\n    if hvd.rank() 
== 0:\n        _logger.info(f\"Loading CLIP ViT-B/16 checkpoints.\")\n    loaded_state_dict = torch.load(pretrained_model) \n\n    ## Initializing temporal attention\n    new_state_dict = loaded_state_dict.copy()\n    for key in loaded_state_dict:\n        if 'blocks' in key and 'attn' in key:\n            new_key = key.replace('attn','temporal_attn')\n            if not new_key in loaded_state_dict:\n                new_state_dict[new_key] = loaded_state_dict[key]\n            else:\n                new_state_dict[new_key] = loaded_state_dict[new_key]\n        if 'blocks' in key and 'norm1' in key:\n            new_key = key.replace('norm1','temporal_norm1')\n            if not new_key in loaded_state_dict:\n                new_state_dict[new_key] = loaded_state_dict[key]\n            else:\n                new_state_dict[new_key] = loaded_state_dict[new_key]\n\n    loaded_state_dict = new_state_dict\n\n    loaded_keys = loaded_state_dict.keys()\n    model_keys = model.state_dict().keys()\n\n    load_not_in_model = [k for k in loaded_keys if k not in model_keys]\n    model_not_in_load = [k for k in model_keys if k not in loaded_keys]\n\n    toload = dict() \n    mismatched_shape_keys = []\n    for k in model_keys:\n        if k in loaded_keys:\n            if model.state_dict()[k].shape != loaded_state_dict[k].shape:\n                mismatched_shape_keys.append(k)\n            else:\n                toload[k] = loaded_state_dict[k]\n\n    if hvd.rank() == 0:\n        _logger.info(\"Keys in loaded but not in model:\")\n        _logger.info(f\"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}\")\n        _logger.info(\"Keys in model but not in loaded:\")\n        _logger.info(f\"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}\")\n        _logger.info(\"Keys in model and loaded, but shape mismatched:\")\n        _logger.info(f\"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}\")\n\n    model.load_state_dict(toload, strict=False)\n\n\ndef load_pretrained_imagenet(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):\n    import timm\n\n    if hvd.rank() == 0:\n        _logger.info(f\"Loading vit_base_patch16_224 checkpoints.\")\n    loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True).state_dict()\n\n    del loaded_state_dict['head.weight']\n    del loaded_state_dict['head.bias']\n\n    ## Initializing temporal attention\n    new_state_dict = loaded_state_dict.copy()\n    for key in loaded_state_dict:\n        if 'blocks' in key and 'attn' in key:\n            new_key = key.replace('attn','temporal_attn')\n            if not new_key in loaded_state_dict:\n                new_state_dict[new_key] = loaded_state_dict[key]\n            else:\n                new_state_dict[new_key] = loaded_state_dict[new_key]\n        if 'blocks' in key and 'norm1' in key:\n            new_key = key.replace('norm1','temporal_norm1')\n            if not new_key in loaded_state_dict:\n                new_state_dict[new_key] = loaded_state_dict[key]\n            else:\n                new_state_dict[new_key] = loaded_state_dict[new_key]\n\n    loaded_state_dict = new_state_dict\n\n    loaded_keys = loaded_state_dict.keys()\n    model_keys = model.state_dict().keys()\n\n    load_not_in_model = [k for k in loaded_keys if k not in model_keys]\n    model_not_in_load = [k for k in model_keys if k not in loaded_keys]\n\n    toload = dict() \n    mismatched_shape_keys = []\n    for k in 
model_keys:\n        if k in loaded_keys:\n            if model.state_dict()[k].shape != loaded_state_dict[k].shape:\n                mismatched_shape_keys.append(k)\n            else:\n                toload[k] = loaded_state_dict[k]\n\n    if hvd.rank() == 0:\n        _logger.info(\"Keys in loaded but not in model:\")\n        _logger.info(f\"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}\")\n        _logger.info(\"Keys in model but not in loaded:\")\n        _logger.info(f\"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}\")\n        _logger.info(\"Keys in model and loaded, but shape mismatched:\")\n        _logger.info(f\"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}\")\n\n    model.load_state_dict(toload, strict=False)\n\ndef load_pretrained_kinetics(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):\n    if cfg is None:\n        cfg = getattr(model, 'default_cfg')\n    if cfg is None or 'url' not in cfg or not cfg['url']:\n        _logger.warning(\"Pretrained model URL is invalid, using random initialization.\")\n        return\n\n    assert len(pretrained_model) > 0, \"Path to pre-trained Kinetics weights not provided.\"\n\n    state_dict = load_state_dict(pretrained_model)\n\n    classifier_name = cfg['classifier']\n    if ignore_classifier:\n\n        classifier_weight_key = classifier_name + '.weight'\n        classifier_bias_key = classifier_name + '.bias'\n\n        state_dict[classifier_weight_key] = model.state_dict()[classifier_weight_key]\n        state_dict[classifier_bias_key] = model.state_dict()[classifier_bias_key]\n\n    else:\n        raise NotImplementedError('[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier.')\n\n    ## Resizing the positional embeddings in case they don't match\n    if num_patches + 1 != state_dict['pos_embed'].size(1):\n        new_pos_embed = resize_spatial_embedding(state_dict, 'pos_embed', num_patches)\n        state_dict['pos_embed'] = new_pos_embed\n\n    ## Resizing time embeddings in case they don't match\n    if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):\n        state_dict['time_embed'] = resize_temporal_embedding(state_dict, 'time_embed', num_frames)\n\n    ## Loading the weights\n    try:\n        model.load_state_dict(state_dict, strict=True)\n        _logger.info('Succeeded in loading Kinetics pre-trained weights.')\n    except Exception:\n        _logger.error('Error in loading Kinetics pre-trained weights.')\n\n\ndef resize_spatial_embedding(state_dict, key, num_patches):\n    _logger.info(f\"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}\")\n\n    pos_embed = state_dict[key]\n\n    cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1)\n    other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2)\n\n    new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')\n    new_pos_embed = new_pos_embed.transpose(1, 2)\n    new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)\n\n    return new_pos_embed\n\n\ndef resize_temporal_embedding(state_dict, key, num_frames):\n    _logger.info(f\"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}\")\n\n    time_embed = state_dict[key].transpose(1, 2)\n    new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest')\n\n    return new_time_embed.transpose(1, 2)\n
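\n\nif __name__ == '__main__':\n    # Minimal usage sketch (added for illustration; running it requires this\n    # module's imports, e.g. horovod, to be installed). Stretches an 8-frame\n    # temporal embedding to 16 frames via nearest-neighbor interpolation.\n    sd = {'time_embed': torch.randn(1, 8, 768)}\n    print(resize_temporal_embedding(sd, 'time_embed', 16).shape)\n    # expected: torch.Size([1, 16, 768])\n"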
  },
  {
    "path": "src/modeling/timesformer/linear.py",
    "content": "\"\"\" Linear layer (alternate definition)\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn as nn\n\nclass Linear(nn.Linear):\n    def forward(self, input: torch.Tensor) -> torch.Tensor:\n        if torch.jit.is_scripting():\n            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None\n            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)\n        else:\n            return F.linear(input, self.weight, self.bias)\n"
  },
  {
    "path": "src/modeling/timesformer/operators.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\n# \"\"\"Custom operators.\"\"\"\n\n# import torch\n# import torch.nn as nn\n\n\n# class Swish(nn.Module):\n#     \"\"\"Swish activation function: x * sigmoid(x).\"\"\"\n\n#     def __init__(self):\n#         super(Swish, self).__init__()\n\n#     def forward(self, x):\n#         return SwishEfficient.apply(x)\n\n\n# class SwishEfficient(torch.autograd.Function):\n#     \"\"\"Swish activation function: x * sigmoid(x).\"\"\"\n\n#     @staticmethod\n#     def forward(ctx, x):\n#         result = x * torch.sigmoid(x)\n#         ctx.save_for_backward(x)\n#         return result\n\n#     @staticmethod\n#     def backward(ctx, grad_output):\n#         x = ctx.saved_variables[0]\n#         sigmoid_x = torch.sigmoid(x)\n#         return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))\n\n\n# class SE(nn.Module):\n#     \"\"\"Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid.\"\"\"\n\n#     def _round_width(self, width, multiplier, min_width=8, divisor=8):\n#         \"\"\"\n#         Round width of filters based on width multiplier\n#         Args:\n#             width (int): the channel dimensions of the input.\n#             multiplier (float): the multiplication factor.\n#             min_width (int): the minimum width after multiplication.\n#             divisor (int): the new width should be dividable by divisor.\n#         \"\"\"\n#         if not multiplier:\n#             return width\n\n#         width *= multiplier\n#         min_width = min_width or divisor\n#         width_out = max(\n#             min_width, int(width + divisor / 2) // divisor * divisor\n#         )\n#         if width_out < 0.9 * width:\n#             width_out += divisor\n#         return int(width_out)\n\n#     def __init__(self, dim_in, ratio, relu_act=True):\n#         \"\"\"\n#         Args:\n#             dim_in (int): the channel dimensions of the input.\n#             ratio (float): the channel reduction ratio for squeeze.\n#             relu_act (bool): whether to use ReLU activation instead\n#                 of Swish (default).\n#             divisor (int): the new width should be dividable by divisor.\n#         \"\"\"\n#         super(SE, self).__init__()\n#         self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))\n#         dim_fc = self._round_width(dim_in, ratio)\n#         self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True)\n#         self.fc1_act = nn.ReLU() if relu_act else Swish()\n#         self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True)\n\n#         self.fc2_sig = nn.Sigmoid()\n\n#     def forward(self, x):\n#         x_in = x\n#         for module in self.children():\n#             x = module(x)\n#         return x_in * x\n"
  },
  {
    "path": "src/modeling/timesformer/vit.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Copyright 2020 Ross Wightman\n# Modified Model definition\n\nimport torch\nimport torch.nn as nn\nfrom functools import partial\nimport math\nimport warnings\nimport torch.nn.functional as F\nimport numpy as np\n\nimport torch.utils\nimport torch.utils.checkpoint\n\nfrom src.modeling.timesformer.vit_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\nfrom src.modeling.timesformer.helpers import load_pretrained, load_pretrained_kinetics, load_pretrained_imagenet, load_pretrained_CLIP_ViT\nfrom src.modeling.timesformer.vit_utils import DropPath, to_2tuple, trunc_normal_\n\nfrom src.modeling.xbert import BertAttention\n\n# from .build import MODEL_REGISTRY\nfrom torch import einsum\nfrom einops import rearrange, reduce, repeat\n\nimport src.utils.grad_ckpt as grad_ckpt\nfrom src.utils.logger import LOGGER, TB_LOGGER, add_log_to_file, RunningMeter\n\n\ndef _cfg(url='', **kwargs):\n    return {\n        'url': url,\n        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,\n        'crop_pct': .9, 'interpolation': 'bicubic',\n        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,\n        'first_conv': 'patch_embed.proj', 'classifier': 'head',\n        **kwargs\n    }\n\n\ndefault_cfgs = {\n    'vit_base_patch16_224': _cfg(\n        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',\n        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),\n    ),\n}\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.with_qkv = with_qkv\n        if self.with_qkv:\n            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n            self.proj = nn.Linear(dim, dim)\n            self.proj_drop = nn.Dropout(proj_drop)\n        self.attn_drop = nn.Dropout(attn_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        if self.with_qkv:\n            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,\n                                      C // self.num_heads).permute(2, 0, 3, 1, 4)\n            q, k, v = qkv[0], qkv[1], qkv[2]\n        else:\n            qkv = x.reshape(B, N, self.num_heads, C //\n                            self.num_heads).permute(0, 2, 1, 3)\n            q, k, v = qkv, qkv, qkv\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        if self.with_qkv:\n            x = self.proj(x)\n            x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n\n    
def __init__(self, dim, num_heads, layer_num, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_type='divided_space_time', \n                 use_grad_checkpointing=False):\n        super().__init__()\n        self.attention_type = attention_type\n        assert(attention_type in ['divided_space_time',\n               'space_only', 'joint_space_time'])\n\n        self.norm1 = norm_layer(dim)\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n        # Temporal Attention Parameters\n        if self.attention_type == 'divided_space_time':\n            self.temporal_norm1 = norm_layer(dim)\n            self.temporal_attn = Attention(\n                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n            self.temporal_fc = nn.Linear(dim, dim)\n\n        # drop path\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer, drop=drop)\n\n        # [dxli]\n        self.layer_num = layer_num\n        self.use_grad_checkpointing = use_grad_checkpointing\n\n    def forward(self, x, B, T, W):\n        num_spatial_tokens = (x.size(1) - 1) // T\n        H = num_spatial_tokens // W\n\n        if self.attention_type in ['space_only', 'joint_space_time']:\n            x = x + self.drop_path(self.attn(self.norm1(x)))\n            x = x + self.drop_path(self.mlp(self.norm2(x)))\n            return x\n        elif self.attention_type == 'divided_space_time':\n            # Temporal\n            xt = x[:, 1:, :]\n            xt = rearrange(xt, 'b (h w t) m -> (b h w) t m',\n                           b=B, h=H, w=W, t=T)\n\n            if self.use_grad_checkpointing:\n                # temporal_attn_out = torch.utils.checkpoint.checkpoint(self.temporal_attn, self.temporal_norm1(xt))\n                temporal_attn_out = grad_ckpt.CheckpointFunction.apply(self.temporal_attn, 1, self.temporal_norm1(xt))\n            else:\n                temporal_attn_out = self.temporal_attn(self.temporal_norm1(xt))\n                # res_temporal = self.drop_path(\n                #     self.temporal_attn(self.temporal_norm1(xt)))\n            res_temporal = self.drop_path(temporal_attn_out)\n\n            res_temporal = rearrange(\n                res_temporal, '(b h w) t m -> b (h w t) m', b=B, h=H, w=W, t=T)\n            res_temporal = self.temporal_fc(res_temporal)\n            xt = x[:, 1:, :] + res_temporal\n\n            # Spatial\n            init_cls_token = x[:, 0, :].unsqueeze(1)\n            cls_token = init_cls_token.repeat(1, T, 1)\n            cls_token = rearrange(\n                cls_token, 'b t m -> (b t) m', b=B, t=T).unsqueeze(1)\n            xs = xt\n            xs = rearrange(xs, 'b (h w t) m -> (b t) (h w) m',\n                           b=B, h=H, w=W, t=T)\n            xs = torch.cat((cls_token, xs), 1)\n\n            # [original]\n            # res_spatial = self.drop_path(self.attn(self.norm1(xs)))\n            if self.use_grad_checkpointing:\n                spatial_attn_out = grad_ckpt.CheckpointFunction.apply(self.attn, 1, self.norm1(xs))\n            else:\n                #
 spatial_attn_out = torch.utils.checkpoint.checkpoint(self.attn, self.norm1(xs))\n                spatial_attn_out = self.attn(self.norm1(xs))\n            res_spatial = self.drop_path(spatial_attn_out)\n\n            # Taking care of CLS token\n            cls_token = res_spatial[:, 0, :]\n            cls_token = rearrange(cls_token, '(b t) m -> b t m', b=B, t=T)\n            # averaging for every frame\n            cls_token = torch.mean(cls_token, 1, True)\n            res_spatial = res_spatial[:, 1:, :]\n            res_spatial = rearrange(\n                res_spatial, '(b t) (h w) m -> b (h w t) m', b=B, h=H, w=W, t=T)\n            res = res_spatial\n            x = xt\n\n            # Mlp\n            x = torch.cat((init_cls_token, x), 1) + \\\n                torch.cat((cls_token, res), 1)\n\n            x_res = x\n\n            x = self.norm2(x)\n            # x = x + self.drop_path(self.mlp(self.norm2(x)))\n\n            # MLP\n            # [original]\n            # x = x_res + self.drop_path(self.mlp(x))\n            if self.use_grad_checkpointing:\n                # mlp_out = torch.utils.checkpoint.checkpoint(self.mlp, x)\n                mlp_out = grad_ckpt.CheckpointFunction.apply(self.mlp, 1, x)\n            else:\n                mlp_out = self.mlp(x)\n\n            x = x_res + self.drop_path(mlp_out)\n            return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * \\\n            (img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(in_chans, embed_dim,\n                              kernel_size=patch_size, stride=patch_size)\n\n    def forward(self, x):\n        B, C, T, H, W = x.shape\n        x = rearrange(x, 'b c t h w -> (b t) c h w')\n        x = self.proj(x)\n        W = x.size(-1)\n        x = x.flatten(2).transpose(1, 2)\n        return x, T, W\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,\n                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,\n                 drop_path_rate=0.1, hybrid_backbone=None, norm_layer=nn.LayerNorm, num_frames=8, attention_type='divided_space_time', dropout=0., \n                 cross_attention_config=None, use_grad_checkpointing=False):\n        super().__init__()\n\n        self.attention_type = attention_type\n        self.depth = depth\n        self.dropout = nn.Dropout(dropout)\n        self.num_classes = num_classes\n        # num_features for consistency with other models\n        self.num_features = self.embed_dim = embed_dim\n        self.patch_embed = PatchEmbed(\n            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        # Positional Embeddings\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))\n        self.pos_drop = nn.Dropout(p=drop_rate)\n        if self.attention_type != 'space_only':\n            
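# learnable temporal position embedding, one vector per input frame\n            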
self.time_embed = nn.Parameter(\n                torch.zeros(1, num_frames, embed_dim))\n            self.time_drop = nn.Dropout(p=drop_rate)\n\n        # Attention Blocks\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,\n                                                self.depth)]  # stochastic depth decay rule\n        self.blocks = nn.ModuleList([\n            Block(layer_num=i, use_grad_checkpointing=use_grad_checkpointing,\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, attention_type=self.attention_type)\n            for i in range(self.depth)])\n        self.norm = norm_layer(embed_dim)\n\n        # Classifier head\n        self.head = nn.Linear(\n            embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n        trunc_normal_(self.pos_embed, std=.02)\n        trunc_normal_(self.cls_token, std=.02)\n        self.apply(self._init_weights)\n\n        # initialization of temporal attention weights\n        if self.attention_type == 'divided_space_time':\n            i = 0\n            for m in self.blocks.modules():\n                m_str = str(m)\n                if 'Block' in m_str:\n                    if i > 0:\n                        nn.init.constant_(m.temporal_fc.weight, 0)\n                        nn.init.constant_(m.temporal_fc.bias, 0)\n                    i += 1\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    @torch.jit.ignore\n    def no_weight_decay(self):\n        return {'pos_embed', 'cls_token', 'time_embed'}\n\n    def get_classifier(self):\n        return self.head\n\n    def reset_classifier(self, num_classes, global_pool=''):\n        self.num_classes = num_classes\n        self.head = nn.Linear(\n            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()\n\n    def forward_features(self, x, return_all_tokens=False):\n        B = x.shape[0]\n        x, T, W = self.patch_embed(x)\n        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        # resizing the positional embeddings in case they don't match the input at inference\n        if x.size(1) != self.pos_embed.size(1):\n            pos_embed = self.pos_embed\n            cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)\n            other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)\n            P = int(other_pos_embed.size(2) ** 0.5)\n            H = x.size(1) // W\n            other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P)\n            new_pos_embed = F.interpolate(\n                other_pos_embed, size=(H, W), mode='nearest')\n            new_pos_embed = new_pos_embed.flatten(2)\n            new_pos_embed = new_pos_embed.transpose(1, 2)\n            new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)\n            x = x + new_pos_embed\n        else:\n            x = x + self.pos_embed\n        x = self.pos_drop(x)\n\n        # Time Embeddings\n        if self.attention_type != 'space_only':\n            cls_tokens = x[:B, 0, :].unsqueeze(1)\n            x = x[:, 1:]\n    
        x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)\n            # Resizing time embeddings in case they don't match\n            if T != self.time_embed.size(1):\n                time_embed = self.time_embed.transpose(1, 2)\n                new_time_embed = F.interpolate(\n                    time_embed, size=(T), mode='nearest')\n                new_time_embed = new_time_embed.transpose(1, 2)\n                x = x + new_time_embed\n            else:\n                x = x + self.time_embed\n            x = self.time_drop(x)\n            x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)\n            x = torch.cat((cls_tokens, x), dim=1)\n\n        # Attention blocks\n        for blk in self.blocks:\n            x = blk(x, B, T, W)\n\n        # Predictions for space-only baseline\n        if self.attention_type == 'space_only':\n            x = rearrange(x, '(b t) n m -> b t n m', b=B, t=T)\n            x = torch.mean(x, 1)  # averaging predictions for every frame\n\n        x = self.norm(x)\n\n        if return_all_tokens:\n            return x\n        else:\n            return x[:, 0]\n\n    def forward(self, x):\n        x = self.forward_features(x)\n        x = self.head(x)\n        return x\n\n\ndef _conv_filter(state_dict, patch_size=16):\n    \"\"\" convert patch embedding weight from manual patchify + linear proj to conv\"\"\"\n    out_dict = {}\n    for k, v in state_dict.items():\n        if 'patch_embed.proj.weight' in k:\n            if v.shape[-1] != patch_size:\n                patch_size = v.shape[-1]\n            v = v.reshape((v.shape[0], 3, patch_size, patch_size))\n        out_dict[k] = v\n    return out_dict\n\n\nclass vit_base_patch16_224(nn.Module):\n    def __init__(self, cfg, **kwargs):\n        super(vit_base_patch16_224, self).__init__()\n        self.pretrained = True\n        patch_size = 16\n        self.model = VisionTransformer(img_size=cfg.DATA.TRAIN_CROP_SIZE, num_classes=cfg.MODEL.NUM_CLASSES, patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(\n            nn.LayerNorm, eps=1e-6), drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, num_frames=cfg.DATA.NUM_FRAMES, attention_type=cfg.TIMESFORMER.ATTENTION_TYPE, **kwargs)\n\n        self.attention_type = cfg.TIMESFORMER.ATTENTION_TYPE\n        self.model.default_cfg = default_cfgs['vit_base_patch16_224']\n        self.num_patches = (cfg.DATA.TRAIN_CROP_SIZE // patch_size) * \\\n            (cfg.DATA.TRAIN_CROP_SIZE // patch_size)\n        pretrained_model = cfg.TIMESFORMER.PRETRAINED_MODEL\n        if self.pretrained:\n            load_pretrained(self.model, num_classes=self.model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter,\n                            img_size=cfg.DATA.TRAIN_CROP_SIZE, num_patches=self.num_patches, attention_type=self.attention_type, pretrained_model=pretrained_model)\n\n    def forward(self, x):\n        x = self.model(x)\n        return x\n\n\nclass TimeSformer(nn.Module):\n    def __init__(self, model_cfg, input_format='BGR', cross_attention_config=None, **kwargs):\n        super(TimeSformer, self).__init__()\n\n        self.config_file = str(model_cfg)\n\n        # model-specific configurations\n        self.img_size = model_cfg['img_size']\n        self.patch_size = model_cfg['patch_size']\n        self.num_frames = model_cfg['num_frm']\n        self.attn_drop_rate = model_cfg['attn_drop_rate']\n        self.drop_path_rate = model_cfg['drop_path_rate']\n        self.drop_rate = 
model_cfg['drop_rate']\n        self.use_pooling = model_cfg['use_maxpooling']\n        self.use_grad_ckpt = model_cfg['gradient_checkpointing']\n\n        self.attention_type = 'divided_space_time'\n\n        LOGGER.info(f'Initializing TimeSformer with img_size={self.img_size}, patch_size={self.patch_size}, num_frames={self.num_frames}')\n\n        # will be ignored when loading official pretrained ckpt\n        self.num_classes = 400\n\n        self.input_format = input_format\n        assert input_format == \"RGB\", \"Official TimeSformer uses RGB input.\"\n\n        self.model = VisionTransformer(img_size=self.img_size,\n                                       num_classes=self.num_classes,\n                                       patch_size=self.patch_size,\n                                       embed_dim=768,\n                                       depth=12,\n                                       num_heads=12,\n                                       mlp_ratio=4,\n                                       qkv_bias=True,\n                                       norm_layer=partial(nn.LayerNorm, eps=1e-6),\n                                       drop_rate=self.drop_rate,\n                                       attn_drop_rate=self.attn_drop_rate,\n                                       drop_path_rate=self.drop_path_rate,\n                                       num_frames=self.num_frames,\n                                       attention_type=self.attention_type,\n                                       cross_attention_config=cross_attention_config,\n                                       use_grad_checkpointing=self.use_grad_ckpt,\n                                       **kwargs\n                                       )\n\n        if self.use_pooling:\n            self.maxpool_kernel_size = model_cfg['maxpool_kernel_size']\n            self.maxpooling = torch.nn.MaxPool2d(kernel_size=self.maxpool_kernel_size)\n\n        self.model.default_cfg = default_cfgs['vit_base_patch' + str(self.patch_size)+'_224']\n        self.num_patches = (self.img_size // self.patch_size) * (self.img_size // self.patch_size)\n\n    def forward(self, x):\n        x = self.model(x)\n        return x\n\n    def forward_features(self, x, return_all_tokens=True, pooling='temporal'):\n        b, c, t, h, w = x.shape\n\n        x = self.model.forward_features(x, return_all_tokens=return_all_tokens)\n\n        ## apply pooling\n        W = H = self.img_size // self.patch_size\n        T = self.num_frames\n\n        cls_tokens = x[:, 0, :].unsqueeze(1)\n        other_tokens = x[:, 1:, :]\n\n        x = rearrange(other_tokens, 'b (h w t) m -> b t (h w) m', h=H, w=W, t=T)\n\n        assert pooling in ['temporal', 'spatial', 'none'], 'Invalid pooling type {}'.format(pooling)\n        if pooling == 'temporal':\n            x = torch.mean(x, dim=1)\n            x = torch.cat((cls_tokens, x), dim=1)\n        elif pooling == 'spatial': # spatial pooling\n            # x = torch.max(x, dim=2)[0]\n            x = torch.mean(x, dim=2)\n            x = torch.cat((cls_tokens, x), dim=1)\n        elif pooling == 'none':\n            cls_tokens_repeat = cls_tokens.unsqueeze(1).repeat(1, T, 1, 1)\n            x = torch.cat((cls_tokens_repeat, x), dim=2)\n        else:\n            raise NotImplementedError('Unsupported pooling type {}'.format(pooling))\n\n        return x\n\n    def _get_pooled_features(self, x):\n        b, t, h, w, c = x.shape\n\n        # x = rearrange(x.transpose(2, 4).transpose(3, 4), 'b t h w c -> (b t c) h w')\n        x = rearrange(x, 'b t h w c -> (b t c) h w')\n        x = self.maxpooling(x)\n        x = rearrange(x, '(b t c) h w -> b (t h w) c', b=b, t=t)\n\n        return x\n\n    # NOTE: intentionally shadows nn.Module.load_state_dict; takes a checkpoint\n    # path (or the name of a pretrained config) rather than a state dict.\n    def load_state_dict(self, pretrained_ckpt_path):\n        LOGGER.info('Loading TimeSformer checkpoints from {}'.format(pretrained_ckpt_path))\n\n        if pretrained_ckpt_path == \"vit_base_patch16_224\":\n            load_ckpt_func = load_pretrained_imagenet\n        elif \"CLIP_ViT\" in pretrained_ckpt_path:\n            load_ckpt_func = load_pretrained_CLIP_ViT\n        else:\n            load_ckpt_func = load_pretrained_kinetics\n\n        load_ckpt_func(self.model,\n                       num_classes=self.model.num_classes,\n                       in_chans=3,\n                       filter_fn=_conv_filter,\n                       img_size=self.img_size,\n                       num_frames=self.num_frames,\n                       num_patches=self.num_patches,\n                       attention_type=self.attention_type,\n                       pretrained_model=pretrained_ckpt_path\n                       )\n
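\n\nif __name__ == '__main__':\n    # Minimal construction sketch (added for illustration; the config keys match\n    # the model_cfg fields read in __init__, the values are assumptions, and the\n    # repo's dependencies such as horovod and einops must be importable).\n    cfg = dict(img_size=224, patch_size=16, num_frm=8, drop_rate=0.,\n               attn_drop_rate=0., drop_path_rate=0.1,\n               use_maxpooling=False, gradient_checkpointing=False)\n    model = TimeSformer(cfg, input_format='RGB')\n    feats = model.forward_features(torch.randn(1, 3, 8, 224, 224))\n    print(feats.shape)  # expected: torch.Size([1, 197, 768]) with temporal pooling\n"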
  },
  {
    "path": "src/modeling/timesformer/vit_utils.py",
    "content": "# Copyright 2020 Ross Wightman\n# Various utility functions\n\nimport torch\nimport torch.nn as nn\nfrom functools import partial\nimport math\nimport warnings\nimport torch.nn.functional as F\n\nfrom src.modeling.timesformer.helpers import load_pretrained\nfrom itertools import repeat\nimport collections.abc as container_abcs\n\nDEFAULT_CROP_PCT = 0.875\nIMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)\nIMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)\nIMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)\nIMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)\nIMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)\n\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1. + math.erf(x / math.sqrt(2.))) / 2.\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n                      \"The distribution of values may be incorrect.\",\n                      stacklevel=2)\n\n    with torch.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\ndef trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n    # type: (Tensor, float, float, float, float) -> Tensor\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> nn.init.trunc_normal_(w)\n    \"\"\"\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, container_abcs.Iterable):\n            return x\n        return tuple(repeat(x, n))\n    return parse\nto_2tuple = _ntuple(2)\n\n# Calculate symmetric padding for a convolution\ndef get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:\n    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2\n    return padding\n\ndef get_padding_value(padding, kernel_size, **kwargs):\n    dynamic = False\n    if isinstance(padding, str):\n        # for any string padding, the padding will be calculated for you, one of three ways\n        padding = padding.lower()\n        if padding == 'same':\n            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact\n            if is_static_pad(kernel_size, **kwargs):\n                # static case, no extra overhead\n                padding = get_padding(kernel_size, **kwargs)\n            else:\n                # dynamic 'SAME' padding, has runtime/GPU memory overhead\n                padding = 0\n                dynamic = True\n        elif padding == 'valid':\n            # 'VALID' padding, same as padding=0\n            padding = 0\n        else:\n            # Default to PyTorch style 'same'-ish symmetric padding\n            padding = get_padding(kernel_size, **kwargs)\n    return padding, dynamic\n\n# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution\ndef get_same_padding(x: int, k: int, s: int, d: int):\n    return max((int(math.ceil(x // s)) - 1) * s + (k - 1) * d + 1 - x, 0)\n\n\n# Can SAME padding for given args be done statically?\ndef is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):\n    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0\n\n\n# Dynamically pad input x with 'SAME' padding for conv with specified args\n#def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):\ndef pad_same(x, k, s, d=(1, 1), value= 0):\n    ih, iw = x.size()[-2:]\n    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])\n    if pad_h > 0 or pad_w > 0:\n        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)\n    return x\n\ndef adaptive_pool_feat_mult(pool_type='avg'):\n    if pool_type == 'catavgmax':\n        return 2\n    else:\n        return 1\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 
I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n    \"\"\"\n    if drop_prob == 0. or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n    random_tensor.floor_()  # binarize\n    output = x.div(keep_prob) * random_tensor\n    return output\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).\n    \"\"\"\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
  {
    "path": "src/modeling/transformers.py",
    "content": "# coding=utf-8\r\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\r\n# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.\r\n#\r\n# Licensed under the Apache License, Version 2.0 (the \"License\");\r\n# you may not use this file except in compliance with the License.\r\n# You may obtain a copy of the License at\r\n#\r\n#     http://www.apache.org/licenses/LICENSE-2.0\r\n#\r\n# Unless required by applicable law or agreed to in writing, software\r\n# distributed under the License is distributed on an \"AS IS\" BASIS,\r\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n# See the License for the specific language governing permissions and\r\n# limitations under the License.\r\n\"\"\"PyTorch BERT model. \"\"\"\r\n\r\n\r\nimport logging\r\nimport math\r\nimport os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import CrossEntropyLoss, MSELoss\r\n\r\nfrom transformers.activations import gelu, gelu_new, swish\r\nfrom transformers.configuration_bert import BertConfig\r\nfrom transformers.file_utils import (\r\n    add_start_docstrings, add_start_docstrings_to_callable)\r\nfrom transformers.modeling_utils import PreTrainedModel, prune_linear_layer\r\nfrom apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm\r\n\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\nBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [\r\n    \"bert-base-uncased\",\r\n    \"bert-large-uncased\",\r\n    \"bert-base-cased\",\r\n    \"bert-large-cased\",\r\n    \"bert-base-multilingual-uncased\",\r\n    \"bert-base-multilingual-cased\",\r\n    \"bert-base-chinese\",\r\n    \"bert-base-german-cased\",\r\n    \"bert-large-uncased-whole-word-masking\",\r\n    \"bert-large-cased-whole-word-masking\",\r\n    \"bert-large-uncased-whole-word-masking-finetuned-squad\",\r\n    \"bert-large-cased-whole-word-masking-finetuned-squad\",\r\n    \"bert-base-cased-finetuned-mrpc\",\r\n    \"bert-base-german-dbmdz-cased\",\r\n    \"bert-base-german-dbmdz-uncased\",\r\n    \"cl-tohoku/bert-base-japanese\",\r\n    \"cl-tohoku/bert-base-japanese-whole-word-masking\",\r\n    \"cl-tohoku/bert-base-japanese-char\",\r\n    \"cl-tohoku/bert-base-japanese-char-whole-word-masking\",\r\n    \"TurkuNLP/bert-base-finnish-cased-v1\",\r\n    \"TurkuNLP/bert-base-finnish-uncased-v1\",\r\n    \"wietsedv/bert-base-dutch-cased\",\r\n    # See all BERT models at https://huggingface.co/models?filter=bert\r\n]\r\n\r\n\r\ndef load_tf_weights_in_bert(model, config, tf_checkpoint_path):\r\n    \"\"\" Load tf checkpoints in a pytorch model.\r\n    \"\"\"\r\n    try:\r\n        import re\r\n        import numpy as np\r\n        import tensorflow as tf\r\n    except ImportError:\r\n        logger.error(\r\n            \"Loading a TensorFlow model in PyTorch,\"\r\n            \" requires TensorFlow to be installed. 
Please see \"\r\n            \"https://www.tensorflow.org/install/ \"\r\n            \"for installation instructions.\"\r\n        )\r\n        raise\r\n    tf_path = os.path.abspath(tf_checkpoint_path)\r\n    logger.info(\"Converting TensorFlow checkpoint from {}\".format(tf_path))\r\n    # Load weights from TF model\r\n    init_vars = tf.train.list_variables(tf_path)\r\n    names = []\r\n    arrays = []\r\n    for name, shape in init_vars:\r\n        logger.info(\"Loading TF weight {} with shape {}\".format(name, shape))\r\n        array = tf.train.load_variable(tf_path, name)\r\n        names.append(name)\r\n        arrays.append(array)\r\n\r\n    for name, array in zip(names, arrays):\r\n        name = name.split(\"/\")\r\n        # adam_v and adam_m are variables used\r\n        # in AdamWeightDecayOptimizer to calculated m and v\r\n        # which are not required for using pretrained model\r\n        if any(\r\n            n in [\"adam_v\", \"adam_m\", \"AdamWeightDecayOptimizer\",\r\n                  \"AdamWeightDecayOptimizer_1\", \"global_step\"]\r\n            for n in name\r\n        ):\r\n            logger.info(\"Skipping {}\".format(\"/\".join(name)))\r\n            continue\r\n        pointer = model\r\n        for m_name in name:\r\n            if re.fullmatch(r\"[A-Za-z]+_\\d+\", m_name):\r\n                scope_names = re.split(r\"_(\\d+)\", m_name)\r\n            else:\r\n                scope_names = [m_name]\r\n            if scope_names[0] == \"kernel\" or scope_names[0] == \"gamma\":\r\n                pointer = getattr(pointer, \"weight\")\r\n            elif scope_names[0] == \"output_bias\" or scope_names[0] == \"beta\":\r\n                pointer = getattr(pointer, \"bias\")\r\n            elif scope_names[0] == \"output_weights\":\r\n                pointer = getattr(pointer, \"weight\")\r\n            elif scope_names[0] == \"squad\":\r\n                pointer = getattr(pointer, \"classifier\")\r\n            else:\r\n                try:\r\n                    pointer = getattr(pointer, scope_names[0])\r\n                except AttributeError:\r\n                    logger.info(\"Skipping {}\".format(\"/\".join(name)))\r\n                    continue\r\n            if len(scope_names) >= 2:\r\n                num = int(scope_names[1])\r\n                pointer = pointer[num]\r\n        if m_name[-11:] == \"_embeddings\":\r\n            pointer = getattr(pointer, \"weight\")\r\n        elif m_name == \"kernel\":\r\n            array = np.transpose(array)\r\n        try:\r\n            assert pointer.shape == array.shape\r\n        except AssertionError as e:\r\n            e.args += (pointer.shape, array.shape)\r\n            raise\r\n        logger.info(\"Initialize PyTorch weight {}\".format(name))\r\n        pointer.data = torch.from_numpy(array)\r\n    return model\r\n\r\n\r\ndef mish(x):\r\n    return x * torch.tanh(nn.functional.softplus(x))\r\n\r\n\r\nACT2FN = {\"gelu\": gelu, \"relu\": torch.nn.functional.relu,\r\n          \"swish\": swish, \"gelu_new\": gelu_new, \"mish\": mish}\r\n\r\n\r\nBertLayerNorm = LayerNorm\r\n\r\n\r\nclass BertEmbeddings(nn.Module):\r\n    \"\"\"Construct the embeddings from word, position and token_type embeddings.\r\n    \"\"\"\r\n\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.word_embeddings = nn.Embedding(\r\n            config.vocab_size, config.hidden_size,\r\n            padding_idx=config.pad_token_id)\r\n        self.position_embeddings = nn.Embedding(\r\n            
config.max_position_embeddings, config.hidden_size)\r\n        self.token_type_embeddings = nn.Embedding(\r\n            config.type_vocab_size, config.hidden_size)\r\n\r\n        # self.LayerNorm is not snake-cased to stick with\r\n        # TensorFlow model variable name and be able to load\r\n        # any TensorFlow checkpoint file\r\n        self.LayerNorm = BertLayerNorm(\r\n            config.hidden_size, eps=config.layer_norm_eps)\r\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\r\n\r\n    def forward(self, input_ids=None, token_type_ids=None,\r\n                position_ids=None, inputs_embeds=None):\r\n        if input_ids is not None:\r\n            input_shape = input_ids.size()\r\n        else:\r\n            input_shape = inputs_embeds.size()[:-1]\r\n\r\n        seq_length = input_shape[1]\r\n        device = input_ids.device if input_ids is not None\\\r\n            else inputs_embeds.device\r\n        if position_ids is None:\r\n            position_ids = torch.arange(\r\n                seq_length, dtype=torch.long, device=device)\r\n            position_ids = position_ids.unsqueeze(0).expand(input_shape)\r\n        if token_type_ids is None:\r\n            token_type_ids = torch.zeros(\r\n                input_shape, dtype=torch.long, device=device)\r\n\r\n        if inputs_embeds is None:\r\n            inputs_embeds = self.word_embeddings(input_ids)\r\n        position_embeddings = self.position_embeddings(position_ids)\r\n        token_type_embeddings = self.token_type_embeddings(token_type_ids)\r\n\r\n        embeddings = (\r\n            inputs_embeds + position_embeddings + token_type_embeddings)\r\n        embeddings = self.LayerNorm(embeddings)\r\n        embeddings = self.dropout(embeddings)\r\n        return embeddings\r\n\r\n\r\nclass BertSelfAttention(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        if config.hidden_size % config.num_attention_heads != 0 and\\\r\n                not hasattr(config, \"embedding_size\"):\r\n            raise ValueError(\r\n                \"The hidden size (%d) is not a multiple of the number of attention \"\r\n                \"heads (%d)\" % (config.hidden_size, config.num_attention_heads)\r\n            )\r\n        self.output_attentions = config.output_attentions\r\n\r\n        self.num_attention_heads = config.num_attention_heads\r\n        self.attention_head_size = int(\r\n            config.hidden_size / config.num_attention_heads)\r\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\r\n\r\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\r\n        self.key = nn.Linear(config.hidden_size, self.all_head_size)\r\n        self.value = nn.Linear(config.hidden_size, self.all_head_size)\r\n\r\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\r\n\r\n    def transpose_for_scores(self, x):\r\n        new_x_shape = x.size()[:-1] + (\r\n            self.num_attention_heads, self.attention_head_size)\r\n        x = x.view(*new_x_shape)\r\n        return x.permute(0, 2, 1, 3)\r\n\r\n    def forward(\r\n        self,\r\n        hidden_states,\r\n        attention_mask=None,\r\n        head_mask=None,\r\n        encoder_hidden_states=None,\r\n        encoder_attention_mask=None,\r\n    ):\r\n        mixed_query_layer = self.query(hidden_states)\r\n\r\n        # If this is instantiated as a cross-attention module, the keys\r\n        # and values come from an encoder; the attention mask needs to be\r\n      
  # such that the encoder's padding tokens are not attended to.\r\n        if encoder_hidden_states is not None:\r\n            mixed_key_layer = self.key(encoder_hidden_states)\r\n            mixed_value_layer = self.value(encoder_hidden_states)\r\n            attention_mask = encoder_attention_mask\r\n        else:\r\n            mixed_key_layer = self.key(hidden_states)\r\n            mixed_value_layer = self.value(hidden_states)\r\n\r\n        query_layer = self.transpose_for_scores(mixed_query_layer)\r\n        key_layer = self.transpose_for_scores(mixed_key_layer)\r\n        value_layer = self.transpose_for_scores(mixed_value_layer)\r\n\r\n        # Take the dot product between \"query\" and \"key\"\r\n        # to get the raw attention scores.\r\n        attention_scores = torch.matmul(\r\n            query_layer, key_layer.transpose(-1, -2))\r\n        attention_scores = attention_scores / math.sqrt(\r\n            self.attention_head_size)\r\n        if attention_mask is not None:\r\n            # Apply the attention mask is\r\n            # (precomputed for all layers in BertModel forward() function)\r\n            attention_scores = attention_scores + attention_mask\r\n\r\n        # Normalize the attention scores to probabilities.\r\n        attention_probs = nn.Softmax(dim=-1)(attention_scores)\r\n\r\n        # This is actually dropping out entire tokens to attend to, which might\r\n        # seem a bit unusual, but is taken from the original Transformer paper.\r\n        attention_probs = self.dropout(attention_probs)\r\n\r\n        # Mask heads if we want to\r\n        if head_mask is not None:\r\n            attention_probs = attention_probs * head_mask\r\n\r\n        context_layer = torch.matmul(attention_probs, value_layer)\r\n\r\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\r\n        new_context_layer_shape = context_layer.size()[:-2] + (\r\n            self.all_head_size,)\r\n        context_layer = context_layer.view(*new_context_layer_shape)\r\n\r\n        outputs = (context_layer, attention_probs\r\n                   ) if self.output_attentions else (context_layer,)\r\n        return outputs\r\n\r\n\r\nclass BertSelfOutput(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\r\n        self.LayerNorm = BertLayerNorm(\r\n            config.hidden_size, eps=config.layer_norm_eps)\r\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\r\n\r\n    def forward(self, hidden_states, input_tensor):\r\n        hidden_states = self.dense(hidden_states)\r\n        hidden_states = self.dropout(hidden_states)\r\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\r\n        return hidden_states\r\n\r\n\r\nclass BertAttention(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.self = BertSelfAttention(config)\r\n        self.output = BertSelfOutput(config)\r\n        self.pruned_heads = set()\r\n\r\n    def prune_heads(self, heads):\r\n        if len(heads) == 0:\r\n            return\r\n        mask = torch.ones(self.self.num_attention_heads,\r\n                          self.self.attention_head_size)\r\n        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads\r\n        for head in heads:\r\n            # Compute how many pruned heads are\r\n            # before the head and move the index accordingly\r\n            head = head - sum(1 if h < head 
else 0 for h in self.pruned_heads)\r\n            mask[head] = 0\r\n        mask = mask.view(-1).contiguous().eq(1)\r\n        index = torch.arange(len(mask))[mask].long()\r\n\r\n        # Prune linear layers\r\n        self.self.query = prune_linear_layer(self.self.query, index)\r\n        self.self.key = prune_linear_layer(self.self.key, index)\r\n        self.self.value = prune_linear_layer(self.self.value, index)\r\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\r\n\r\n        # Update hyper params and store pruned heads\r\n        self.self.num_attention_heads = self.self.num_attention_heads - len(\r\n            heads)\r\n        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads\r\n        self.pruned_heads = self.pruned_heads.union(heads)\r\n\r\n    def forward(\r\n        self,\r\n        hidden_states,\r\n        attention_mask=None,\r\n        head_mask=None,\r\n        encoder_hidden_states=None,\r\n        encoder_attention_mask=None,\r\n    ):\r\n        self_outputs = self.self(\r\n            hidden_states, attention_mask, head_mask,\r\n            encoder_hidden_states, encoder_attention_mask\r\n        )\r\n        attention_output = self.output(self_outputs[0], hidden_states)\r\n        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them\r\n        return outputs\r\n\r\n\r\nclass BertIntermediate(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\r\n        if isinstance(config.hidden_act, str):\r\n            self.intermediate_act_fn = ACT2FN[config.hidden_act]\r\n        else:\r\n            self.intermediate_act_fn = config.hidden_act\r\n\r\n    def forward(self, hidden_states):\r\n        hidden_states = self.dense(hidden_states)\r\n        hidden_states = self.intermediate_act_fn(hidden_states)\r\n        return hidden_states\r\n\r\n\r\nclass BertOutput(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\r\n        self.LayerNorm = BertLayerNorm(\r\n            config.hidden_size, eps=config.layer_norm_eps)\r\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\r\n\r\n    def forward(self, hidden_states, input_tensor):\r\n        hidden_states = self.dense(hidden_states)\r\n        hidden_states = self.dropout(hidden_states)\r\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\r\n        return hidden_states\r\n\r\n\r\nclass BertLayer(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.attention = BertAttention(config)\r\n        self.is_decoder = config.is_decoder\r\n        if self.is_decoder:\r\n            self.crossattention = BertAttention(config)\r\n        self.intermediate = BertIntermediate(config)\r\n        self.output = BertOutput(config)\r\n\r\n    def forward(\r\n        self,\r\n        hidden_states,\r\n        attention_mask=None,\r\n        head_mask=None,\r\n        encoder_hidden_states=None,\r\n        encoder_attention_mask=None,\r\n    ):\r\n        self_attention_outputs = self.attention(\r\n            hidden_states, attention_mask, head_mask)\r\n        attention_output = self_attention_outputs[0]\r\n        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights\r\n\r\n        if self.is_decoder and encoder_hidden_states is not 
None:\r\n            cross_attention_outputs = self.crossattention(\r\n                attention_output, attention_mask, head_mask,\r\n                encoder_hidden_states, encoder_attention_mask\r\n            )\r\n            attention_output = cross_attention_outputs[0]\r\n            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights\r\n\r\n        intermediate_output = self.intermediate(attention_output)\r\n        layer_output = self.output(intermediate_output, attention_output)\r\n        outputs = (layer_output,) + outputs\r\n        return outputs\r\n\r\n\r\nclass BertEncoder(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.output_attentions = config.output_attentions\r\n        self.output_hidden_states = config.output_hidden_states\r\n        self.layer = nn.ModuleList([BertLayer(config) for _ in range(\r\n            config.num_hidden_layers)])\r\n\r\n    def forward(\r\n        self,\r\n        hidden_states,\r\n        attention_mask=None,\r\n        head_mask=None,\r\n        encoder_hidden_states=None,\r\n        encoder_attention_mask=None,\r\n    ):\r\n        all_hidden_states = ()\r\n        all_attentions = ()\r\n        for i, layer_module in enumerate(self.layer):\r\n            if self.output_hidden_states:\r\n                all_hidden_states = all_hidden_states + (hidden_states,)\r\n\r\n            layer_outputs = layer_module(\r\n                hidden_states, attention_mask, head_mask[i],\r\n                encoder_hidden_states, encoder_attention_mask\r\n            )\r\n            hidden_states = layer_outputs[0]\r\n\r\n            if self.output_attentions:\r\n                all_attentions = all_attentions + (layer_outputs[1],)\r\n\r\n        # Add last layer\r\n        if self.output_hidden_states:\r\n            all_hidden_states = all_hidden_states + (hidden_states,)\r\n\r\n        outputs = (hidden_states,)\r\n        if self.output_hidden_states:\r\n            outputs = outputs + (all_hidden_states,)\r\n        if self.output_attentions:\r\n            outputs = outputs + (all_attentions,)\r\n        return outputs  # last-layer hidden state, (all hidden states), (all attentions)\r\n\r\n\r\nclass BertPooler(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\r\n        self.activation = nn.Tanh()\r\n\r\n    def forward(self, hidden_states):\r\n        # We \"pool\" the model by simply taking the hidden state corresponding\r\n        # to the first token.\r\n        first_token_tensor = hidden_states[:, 0]\r\n        pooled_output = self.dense(first_token_tensor)\r\n        pooled_output = self.activation(pooled_output)\r\n        return pooled_output\r\n\r\n\r\nclass BertPredictionHeadTransform(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\r\n        if isinstance(config.hidden_act, str):\r\n            self.transform_act_fn = ACT2FN[config.hidden_act]\r\n        else:\r\n            self.transform_act_fn = config.hidden_act\r\n        self.LayerNorm = BertLayerNorm(\r\n            config.hidden_size, eps=config.layer_norm_eps)\r\n\r\n    def forward(self, hidden_states):\r\n        hidden_states = self.dense(hidden_states)\r\n        hidden_states = self.transform_act_fn(hidden_states)\r\n        hidden_states = self.LayerNorm(hidden_states)\r\n        
return hidden_states\r\n\r\n\r\nclass BertLMPredictionHead(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.transform = BertPredictionHeadTransform(config)\r\n\r\n        # The output weights are the same as the input embeddings, but there is\r\n        # an output-only bias for each token.\r\n        self.decoder = nn.Linear(\r\n            config.hidden_size, config.vocab_size, bias=False)\r\n\r\n        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\r\n\r\n        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`\r\n        self.decoder.bias = self.bias\r\n\r\n    def forward(self, hidden_states):\r\n        hidden_states = self.transform(hidden_states)\r\n        hidden_states = self.decoder(hidden_states)\r\n        return hidden_states\r\n\r\n\r\nclass BertOnlyMLMHead(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.predictions = BertLMPredictionHead(config)\r\n\r\n    def forward(self, sequence_output):\r\n        prediction_scores = self.predictions(sequence_output)\r\n        return prediction_scores\r\n\r\n\r\nclass BertOnlyNSPHead(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.seq_relationship = nn.Linear(config.hidden_size, 2)\r\n\r\n    def forward(self, pooled_output):\r\n        seq_relationship_score = self.seq_relationship(pooled_output)\r\n        return seq_relationship_score\r\n\r\n\r\nclass BertPreTrainingHeads(nn.Module):\r\n    def __init__(self, config):\r\n        super().__init__()\r\n        self.predictions = BertLMPredictionHead(config)\r\n        self.seq_relationship = nn.Linear(config.hidden_size, 2)\r\n\r\n    def forward(self, sequence_output, pooled_output):\r\n        prediction_scores = self.predictions(sequence_output)\r\n        seq_relationship_score = self.seq_relationship(pooled_output)\r\n        return prediction_scores, seq_relationship_score\r\n\r\n\r\nclass BertPreTrainedModel(PreTrainedModel):\r\n    \"\"\" An abstract class to handle weights initialization and\r\n        a simple interface for downloading and loading pretrained models.\r\n    \"\"\"\r\n\r\n    config_class = BertConfig\r\n    load_tf_weights = load_tf_weights_in_bert\r\n    base_model_prefix = \"bert\"\r\n\r\n    def _init_weights(self, module):\r\n        \"\"\" Initialize the weights \"\"\"\r\n        if isinstance(module, (nn.Linear, nn.Embedding)):\r\n            # Slightly different from the TF version which uses truncated_normal for initialization\r\n            # cf https://github.com/pytorch/pytorch/pull/5617\r\n            module.weight.data.normal_(\r\n                mean=0.0, std=self.config.initializer_range)\r\n        elif isinstance(module, BertLayerNorm):\r\n            module.bias.data.zero_()\r\n            module.weight.data.fill_(1.0)\r\n        if isinstance(module, nn.Linear) and module.bias is not None:\r\n            module.bias.data.zero_()\r\n"
  },
  {
    "path": "src/modeling/xbert.py",
    "content": "# coding=utf-8\n# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.\n# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch BERT model. \"\"\"\n\nimport math\nimport os\nimport warnings\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor, device, dtype, nn\nimport torch.utils.checkpoint\nfrom torch import nn\nfrom torch.nn import CrossEntropyLoss, MSELoss\nimport torch.nn.functional as F\n\nfrom transformers.activations import ACT2FN\nfrom transformers.file_utils import (\n    ModelOutput,\n    add_code_sample_docstrings,\n    add_start_docstrings,\n    add_start_docstrings_to_model_forward,\n    replace_return_docstrings,\n)\nfrom transformers.modeling_outputs import (\n    BaseModelOutputWithPastAndCrossAttentions,\n    BaseModelOutputWithPoolingAndCrossAttentions,\n    CausalLMOutputWithCrossAttentions,\n    MaskedLMOutput,\n    MultipleChoiceModelOutput,\n    NextSentencePredictorOutput,\n    QuestionAnsweringModelOutput,\n    SequenceClassifierOutput,\n    TokenClassifierOutput,\n)\nfrom transformers.modeling_utils import (\n    PreTrainedModel,\n    apply_chunking_to_forward,\n    find_pruneable_heads_and_indices,\n    prune_linear_layer,\n)\nfrom transformers.utils import logging\nfrom transformers.models.bert.configuration_bert import BertConfig\n\n\nlogger = logging.get_logger(__name__)\n\n_CONFIG_FOR_DOC = \"BertConfig\"\n_TOKENIZER_FOR_DOC = \"BertTokenizer\"\n\nBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [\n    \"bert-base-uncased\",\n    \"bert-large-uncased\",\n    \"bert-base-cased\",\n    \"bert-large-cased\",\n    \"bert-base-multilingual-uncased\",\n    \"bert-base-multilingual-cased\",\n    \"bert-base-chinese\",\n    \"bert-base-german-cased\",\n    \"bert-large-uncased-whole-word-masking\",\n    \"bert-large-cased-whole-word-masking\",\n    \"bert-large-uncased-whole-word-masking-finetuned-squad\",\n    \"bert-large-cased-whole-word-masking-finetuned-squad\",\n    \"bert-base-cased-finetuned-mrpc\",\n    \"bert-base-german-dbmdz-cased\",\n    \"bert-base-german-dbmdz-uncased\",\n    \"cl-tohoku/bert-base-japanese\",\n    \"cl-tohoku/bert-base-japanese-whole-word-masking\",\n    \"cl-tohoku/bert-base-japanese-char\",\n    \"cl-tohoku/bert-base-japanese-char-whole-word-masking\",\n    \"TurkuNLP/bert-base-finnish-cased-v1\",\n    \"TurkuNLP/bert-base-finnish-uncased-v1\",\n    \"wietsedv/bert-base-dutch-cased\",\n    # See all BERT models at https://huggingface.co/models?filter=bert\n]\n\n\ndef load_tf_weights_in_bert(model, config, tf_checkpoint_path):\n    \"\"\"Load tf checkpoints in a pytorch model.\"\"\"\n    try:\n        import re\n\n        import numpy as np\n        import tensorflow as tf\n    except ImportError:\n        logger.error(\n            \"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see \"\n            \"https://www.tensorflow.org/install/ for installation instructions.\"\n        )\n        raise\n    tf_path = os.path.abspath(tf_checkpoint_path)\n    logger.info(\"Converting TensorFlow checkpoint from {}\".format(tf_path))\n    # Load weights from TF model\n    init_vars = tf.train.list_variables(tf_path)\n    names = []\n    arrays = []\n    for name, shape in init_vars:\n        logger.info(\"Loading TF weight {} with shape {}\".format(name, shape))\n        array = tf.train.load_variable(tf_path, name)\n        names.append(name)\n        arrays.append(array)\n\n    for name, array in zip(names, arrays):\n        name = name.split(\"/\")\n        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v\n        # which are not required for using pretrained model\n        if any(\n            n in [\"adam_v\", \"adam_m\", \"AdamWeightDecayOptimizer\", \"AdamWeightDecayOptimizer_1\", \"global_step\"]\n            for n in name\n        ):\n            logger.info(\"Skipping {}\".format(\"/\".join(name)))\n            continue\n        pointer = model\n        for m_name in name:\n            if re.fullmatch(r\"[A-Za-z]+_\\d+\", m_name):\n                scope_names = re.split(r\"_(\\d+)\", m_name)\n            else:\n                scope_names = [m_name]\n            if scope_names[0] == \"kernel\" or scope_names[0] == \"gamma\":\n                pointer = getattr(pointer, \"weight\")\n            elif scope_names[0] == \"output_bias\" or scope_names[0] == \"beta\":\n                pointer = getattr(pointer, \"bias\")\n            elif scope_names[0] == \"output_weights\":\n                pointer = getattr(pointer, \"weight\")\n            elif scope_names[0] == \"squad\":\n                pointer = getattr(pointer, \"classifier\")\n            else:\n                try:\n                    pointer = getattr(pointer, scope_names[0])\n                except AttributeError:\n                    logger.info(\"Skipping {}\".format(\"/\".join(name)))\n                    continue\n            if len(scope_names) >= 2:\n                num = int(scope_names[1])\n                pointer = pointer[num]\n        if m_name[-11:] == \"_embeddings\":\n            pointer = getattr(pointer, \"weight\")\n        elif m_name == \"kernel\":\n            array = np.transpose(array)\n        try:\n            assert (\n                pointer.shape == array.shape\n            ), f\"Pointer shape {pointer.shape} and array shape {array.shape} mismatched\"\n        except AssertionError as e:\n            e.args += (pointer.shape, array.shape)\n            raise\n        logger.info(\"Initialize PyTorch weight {}\".format(name))\n        pointer.data = torch.from_numpy(array)\n    return model\n\n\nclass BertEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n\n    def __init__(self, config):\n        super().__init__()\n        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)\n        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)\n        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)\n\n        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load\n        # any TensorFlow checkpoint file\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        
self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n        # position_ids (1, len position emb) is contiguous in memory and exported when serialized\n        self.register_buffer(\"position_ids\", torch.arange(config.max_position_embeddings).expand((1, -1)))\n        self.position_embedding_type = getattr(config, \"position_embedding_type\", \"absolute\")\n        \n        self.config = config\n\n    def forward(\n        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0\n    ):\n        if input_ids is not None:\n            input_shape = input_ids.size()\n        else:\n            input_shape = inputs_embeds.size()[:-1]\n\n        seq_length = input_shape[1]\n\n        if position_ids is None:\n            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]\n\n        if token_type_ids is None:\n            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)\n\n        if inputs_embeds is None:\n            inputs_embeds = self.word_embeddings(input_ids)\n        \n        token_type_embeddings = self.token_type_embeddings(token_type_ids)  \n\n        embeddings = inputs_embeds + token_type_embeddings\n        if self.position_embedding_type == \"absolute\":\n            position_embeddings = self.position_embeddings(position_ids)\n            embeddings += position_embeddings\n        embeddings = self.LayerNorm(embeddings)\n        embeddings = self.dropout(embeddings)\n        return embeddings\n\n\nclass BertSelfAttention(nn.Module):\n    def __init__(self, config, is_cross_attention):\n        super().__init__()\n        self.config = config\n        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, \"embedding_size\"):\n            raise ValueError(\n                \"The hidden size (%d) is not a multiple of the number of attention \"\n                \"heads (%d)\" % (config.hidden_size, config.num_attention_heads)\n            )\n        \n        self.num_attention_heads = config.num_attention_heads\n        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n        self.all_head_size = self.num_attention_heads * self.attention_head_size\n\n        self.query = nn.Linear(config.hidden_size, self.all_head_size)\n        if is_cross_attention:\n            self.key = nn.Linear(config.encoder_width, self.all_head_size)\n            self.value = nn.Linear(config.encoder_width, self.all_head_size)\n        else:\n            self.key = nn.Linear(config.hidden_size, self.all_head_size)\n            self.value = nn.Linear(config.hidden_size, self.all_head_size)\n\n        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)\n        self.position_embedding_type = getattr(config, \"position_embedding_type\", \"absolute\")\n        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n            self.max_position_embeddings = config.max_position_embeddings\n            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)\n        self.save_attention = False   \n            \n    def save_attn_gradients(self, attn_gradients):\n        self.attn_gradients = attn_gradients\n        \n    def get_attn_gradients(self):\n        return self.attn_gradients\n    \n    def save_attention_map(self, attention_map):\n        self.attention_map = attention_map\n        \n  
  def get_attention_map(self):\n        return self.attention_map\n    \n\n    def transpose_for_scores(self, x):\n        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)\n        x = x.view(*new_x_shape)\n        return x.permute(0, 2, 1, 3)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        mixed_query_layer = self.query(hidden_states)\n\n        # If this is instantiated as a cross-attention module, the keys\n        # and values come from an encoder; the attention mask needs to be\n        # such that the encoder's padding tokens are not attended to.\n        is_cross_attention = encoder_hidden_states is not None\n\n\n        if is_cross_attention:\n            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))\n            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))\n            attention_mask = encoder_attention_mask\n        elif past_key_value is not None:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)\n            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)\n        else:\n            key_layer = self.transpose_for_scores(self.key(hidden_states))\n            value_layer = self.transpose_for_scores(self.value(hidden_states))\n\n        query_layer = self.transpose_for_scores(mixed_query_layer)\n\n        past_key_value = (key_layer, value_layer)\n\n        # Take the dot product between \"query\" and \"key\" to get the raw attention scores.\n        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n\n        if self.position_embedding_type == \"relative_key\" or self.position_embedding_type == \"relative_key_query\":\n            seq_length = hidden_states.size()[1]\n            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)\n            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)\n            distance = position_ids_l - position_ids_r\n            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)\n            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility\n\n            if self.position_embedding_type == \"relative_key\":\n                relative_position_scores = torch.einsum(\"bhld,lrd->bhlr\", query_layer, positional_embedding)\n                attention_scores = attention_scores + relative_position_scores\n            elif self.position_embedding_type == \"relative_key_query\":\n                relative_position_scores_query = torch.einsum(\"bhld,lrd->bhlr\", query_layer, positional_embedding)\n                relative_position_scores_key = torch.einsum(\"bhrd,lrd->bhlr\", key_layer, positional_embedding)\n                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key\n\n        attention_scores = attention_scores / math.sqrt(self.attention_head_size)\n        if attention_mask is not None:\n            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)\n          
  attention_scores = attention_scores + attention_mask\n\n        # Normalize the attention scores to probabilities.\n        attention_probs = nn.Softmax(dim=-1)(attention_scores)\n        \n        if is_cross_attention and self.save_attention:\n            self.save_attention_map(attention_probs)\n            attention_probs.register_hook(self.save_attn_gradients)         \n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        attention_probs_dropped = self.dropout(attention_probs)\n\n        # Mask heads if we want to\n        if head_mask is not None:\n            attention_probs_dropped = attention_probs_dropped * head_mask\n\n        context_layer = torch.matmul(attention_probs_dropped, value_layer)\n\n        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)\n\n        outputs = outputs + (past_key_value,)\n        return outputs\n\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertAttention(nn.Module):\n    def __init__(self, config, is_cross_attention=False):\n        super().__init__()\n        self.self = BertSelfAttention(config, is_cross_attention)\n        self.output = BertSelfOutput(config)\n        self.pruned_heads = set()\n\n    def prune_heads(self, heads):\n        if len(heads) == 0:\n            return\n        heads, index = find_pruneable_heads_and_indices(\n            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads\n        )\n\n        # Prune linear layers\n        self.self.query = prune_linear_layer(self.self.query, index)\n        self.self.key = prune_linear_layer(self.self.key, index)\n        self.self.value = prune_linear_layer(self.self.value, index)\n        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)\n\n        # Update hyper params and store pruned heads\n        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)\n        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads\n        self.pruned_heads = self.pruned_heads.union(heads)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        self_outputs = self.self(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            encoder_hidden_states,\n            encoder_attention_mask,\n            past_key_value,\n            output_attentions,\n        )\n        attention_output = self.output(self_outputs[0], hidden_states)\n    
    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them\n        return outputs\n\n\nclass BertIntermediate(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n        if isinstance(config.hidden_act, str):\n            self.intermediate_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.intermediate_act_fn = config.hidden_act\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\nclass BertOutput(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n\n    def forward(self, hidden_states, input_tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertLayer(nn.Module):\n    def __init__(self, config, layer_num):\n        super().__init__()\n        self.config = config\n        self.chunk_size_feed_forward = config.chunk_size_feed_forward\n        self.seq_len_dim = 1\n        self.attention = BertAttention(config)\n\n        # self.has_cross_attention = (layer_num >= config.fusion_layer)\n        self.has_cross_attention = False \n        self.layer_num = layer_num                \n        if self.has_cross_attention:           \n            self.crossattention = BertAttention(config, is_cross_attention=True)\n        self.intermediate = BertIntermediate(config)\n        self.output = BertOutput(config)\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_value=None,\n        output_attentions=False,\n    ):\n        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2\n        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None\n        self_attention_outputs = self.attention(\n            hidden_states,\n            attention_mask,\n            head_mask,\n            output_attentions=output_attentions,\n            past_key_value=self_attn_past_key_value,\n        )\n        attention_output = self_attention_outputs[0]\n\n        outputs = self_attention_outputs[1:-1]\n        present_key_value = self_attention_outputs[-1]\n\n        if self.has_cross_attention:\n            assert encoder_hidden_states is not None, \"encoder_hidden_states must be given for cross-attention layers\"\n            \n            if type(encoder_hidden_states) == list:\n                cross_attention_outputs = self.crossattention(\n                    attention_output,\n                    attention_mask,\n                    head_mask,\n                    encoder_hidden_states[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)],\n                    encoder_attention_mask[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)],\n                    output_attentions=output_attentions,\n                )    \n                attention_output = 
cross_attention_outputs[0]\n                outputs = outputs + cross_attention_outputs[1:-1]\n         \n            else:\n                cross_attention_outputs = self.crossattention(\n                    attention_output,\n                    attention_mask,\n                    head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    output_attentions=output_attentions,\n                )\n                attention_output = cross_attention_outputs[0]\n                outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights                               \n        layer_output = apply_chunking_to_forward(\n            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output\n        )\n        outputs = (layer_output,) + outputs\n\n        outputs = outputs + (present_key_value,)\n\n        return outputs\n\n    def feed_forward_chunk(self, attention_output):\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        return layer_output\n\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.config = config\n        self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])\n\n    def forward(\n        self,\n        hidden_states,\n        attention_mask=None,\n        head_mask=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=False,\n        output_hidden_states=False,\n        return_dict=True,\n        mode='multi_modal',\n    ):\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attentions = () if output_attentions else None\n        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None\n\n        next_decoder_cache = () if use_cache else None\n        \n                \n        if mode=='text': \n            start_layer = 0\n            output_layer = self.config.fusion_layer\n            \n        elif mode=='fusion':\n            start_layer = self.config.fusion_layer\n            output_layer = self.config.num_hidden_layers\n            \n        elif mode=='multi_modal':\n            start_layer = 0\n            output_layer = self.config.num_hidden_layers        \n        \n        for i in range(start_layer, output_layer):\n            layer_module = self.layer[i]\n            if output_hidden_states:\n                all_hidden_states = all_hidden_states + (hidden_states,)\n\n            layer_head_mask = head_mask[i] if head_mask is not None else None\n            past_key_value = past_key_values[i] if past_key_values is not None else None\n\n            if getattr(self.config, \"gradient_checkpointing\", False) and self.training:\n\n                if use_cache:\n                    logger.warn(\n                        \"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. 
Setting \"\n                        \"`use_cache=False`...\"\n                    )\n                    use_cache = False\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs, past_key_value, output_attentions)\n\n                    return custom_forward\n\n                layer_outputs = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(layer_module),\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                )\n            else:\n                layer_outputs = layer_module(\n                    hidden_states,\n                    attention_mask,\n                    layer_head_mask,\n                    encoder_hidden_states,\n                    encoder_attention_mask,\n                    past_key_value,\n                    output_attentions,\n                )\n\n            hidden_states = layer_outputs[0]\n            if use_cache:\n                next_decoder_cache += (layer_outputs[-1],)\n            if output_attentions:\n                all_self_attentions = all_self_attentions + (layer_outputs[1],)\n\n        if output_hidden_states:\n            all_hidden_states = all_hidden_states + (hidden_states,)\n\n        if not return_dict:\n            return tuple(\n                v\n                for v in [\n                    hidden_states,\n                    next_decoder_cache,\n                    all_hidden_states,\n                    all_self_attentions,\n                    all_cross_attentions,\n                ]\n                if v is not None\n            )\n        return BaseModelOutputWithPastAndCrossAttentions(\n            last_hidden_state=hidden_states,\n            past_key_values=next_decoder_cache,\n            hidden_states=all_hidden_states,\n            attentions=all_self_attentions,\n            cross_attentions=all_cross_attentions,\n        )\n\n\nclass BertPooler(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.activation = nn.Tanh()\n\n    def forward(self, hidden_states):\n        # We \"pool\" the model by simply taking the hidden state corresponding\n        # to the first token.\n        first_token_tensor = hidden_states[:, 0]\n        pooled_output = self.dense(first_token_tensor)\n        pooled_output = self.activation(pooled_output)\n        return pooled_output\n\n\nclass BertPredictionHeadTransform(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        if isinstance(config.hidden_act, str):\n            self.transform_act_fn = ACT2FN[config.hidden_act]\n        else:\n            self.transform_act_fn = config.hidden_act\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.transform_act_fn(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states)\n        return hidden_states\n\n\nclass BertLMPredictionHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.transform = BertPredictionHeadTransform(config)\n\n        # The output weights are the same 
as the input embeddings, but there is\n        # an output-only bias for each token.\n        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n        self.bias = nn.Parameter(torch.zeros(config.vocab_size))\n\n        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`\n        self.decoder.bias = self.bias\n\n    def forward(self, hidden_states):\n        hidden_states = self.transform(hidden_states)\n        hidden_states = self.decoder(hidden_states)\n        return hidden_states\n\n\nclass BertOnlyMLMHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n\n    def forward(self, sequence_output):\n        prediction_scores = self.predictions(sequence_output)\n        return prediction_scores\n\n\nclass BertOnlyNSPHead(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.seq_relationship = nn.Linear(config.hidden_size, 2)\n\n    def forward(self, pooled_output):\n        seq_relationship_score = self.seq_relationship(pooled_output)\n        return seq_relationship_score\n\n\nclass BertPreTrainingHeads(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.predictions = BertLMPredictionHead(config)\n        self.seq_relationship = nn.Linear(config.hidden_size, 2)\n\n    def forward(self, sequence_output, pooled_output):\n        prediction_scores = self.predictions(sequence_output)\n        seq_relationship_score = self.seq_relationship(pooled_output)\n        return prediction_scores, seq_relationship_score\n\n\nclass BertPreTrainedModel(PreTrainedModel):\n    \"\"\"\n    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n    models.\n    \"\"\"\n\n    config_class = BertConfig\n    load_tf_weights = load_tf_weights_in_bert\n    base_model_prefix = \"bert\"\n    _keys_to_ignore_on_load_missing = [r\"position_ids\"]\n\n    def _init_weights(self, module):\n        \"\"\" Initialize the weights \"\"\"\n        if isinstance(module, (nn.Linear, nn.Embedding)):\n            # Slightly different from the TF version which uses truncated_normal for initialization\n            # cf https://github.com/pytorch/pytorch/pull/5617\n            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n        elif isinstance(module, nn.LayerNorm):\n            module.bias.data.zero_()\n            module.weight.data.fill_(1.0)\n        if isinstance(module, nn.Linear) and module.bias is not None:\n            module.bias.data.zero_()\n\n\n@dataclass\nclass BertForPreTrainingOutput(ModelOutput):\n    \"\"\"\n    Output type of :class:`~transformers.BertForPreTraining`.\n    Args:\n        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):\n            Total loss as the sum of the masked language modeling loss and the next sequence prediction\n            (classification) loss.\n        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):\n            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):\n            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation\n            before SoftMax).\n   
     hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):\n            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)\n            of shape :obj:`(batch_size, sequence_length, hidden_size)`.\n            Hidden-states of the model at the output of each layer plus the initial embedding outputs.\n        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):\n            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,\n            sequence_length, sequence_length)`.\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n    \"\"\"\n\n    loss: Optional[torch.FloatTensor] = None\n    prediction_logits: torch.FloatTensor = None\n    seq_relationship_logits: torch.FloatTensor = None\n    hidden_states: Optional[Tuple[torch.FloatTensor]] = None\n    attentions: Optional[Tuple[torch.FloatTensor]] = None\n\n\nBERT_START_DOCSTRING = r\"\"\"\n    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic\n    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,\n    pruning heads etc.)\n    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__\n    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to\n    general usage and behavior.\n    Parameters:\n        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.\n            Initializing with a config file does not load the weights associated with the model, only the\n            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model\n            weights.\n\"\"\"\n\nBERT_INPUTS_DOCSTRING = r\"\"\"\n    Args:\n        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):\n            Indices of input sequence tokens in the vocabulary.\n            Indices can be obtained using :class:`~transformers.BertTokenizer`. See\n            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for\n            details.\n            `What are input IDs? <../glossary.html#input-ids>`__\n        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):\n            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n            `What are attention masks? <../glossary.html#attention-mask>`__\n        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):\n            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,\n            1]``:\n            - 0 corresponds to a `sentence A` token,\n            - 1 corresponds to a `sentence B` token.\n            `What are token type IDs? 
<../glossary.html#token-type-ids>`_\n        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):\n            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,\n            config.max_position_embeddings - 1]``.\n            `What are position IDs? <../glossary.html#position-ids>`_\n        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):\n            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:\n            - 1 indicates the head is **not masked**,\n            - 0 indicates the head is **masked**.\n        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):\n            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.\n            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated\n            vectors than the model's internal embedding lookup matrix.\n        output_attentions (:obj:`bool`, `optional`):\n            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned\n            tensors for more detail.\n        output_hidden_states (:obj:`bool`, `optional`):\n            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for\n            more detail.\n        return_dict (:obj:`bool`, `optional`):\n            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.\n\"\"\"\n\n\n@add_start_docstrings(\n    \"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.\",\n    BERT_START_DOCSTRING,\n)\nclass BertModel(BertPreTrainedModel):\n    \"\"\"\n    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of\n    cross-attention is added between the self-attention layers, following the architecture described in `Attention is\n    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,\n    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.\n    To behave as a decoder, the model needs to be initialized with :obj:`add_cross_attention` set to :obj:`True` in its\n    configuration and called with the :obj:`is_decoder` argument set to :obj:`True`; an :obj:`encoder_hidden_states` is\n    then expected as an input to the forward pass.\n    \"\"\"\n\n    def __init__(self, config, add_pooling_layer=True):\n        super().__init__(config)\n        self.config = config\n\n        self.embeddings = BertEmbeddings(config)\n        \n        self.encoder = BertEncoder(config)\n\n        self.pooler = BertPooler(config) if add_pooling_layer else None\n\n        self.init_weights()\n \n\n    def get_input_embeddings(self):\n        return self.embeddings.word_embeddings\n\n    def set_input_embeddings(self, value):\n        self.embeddings.word_embeddings = value\n\n    def _prune_heads(self, heads_to_prune):\n        \"\"\"\n        Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base\n        class PreTrainedModel\n        \"\"\"\n        for layer, heads in heads_to_prune.items():\n            self.encoder.layer[layer].attention.prune_heads(heads)\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=\"bert-base-uncased\",\n        output_type=BaseModelOutputWithPoolingAndCrossAttentions,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    \n    \n    def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:\n        \"\"\"\n        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.\n\n        Arguments:\n            attention_mask (:obj:`torch.Tensor`):\n                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.\n            input_shape (:obj:`Tuple[int]`):\n                The shape of the input to the model.\n            device: (:obj:`torch.device`):\n                The device of the input to the model.\n\n        Returns:\n            :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.\n        \"\"\"\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        if attention_mask.dim() == 3:\n            extended_attention_mask = attention_mask[:, None, :, :]\n        elif attention_mask.dim() == 2:\n            # Provided a padding mask of dimensions [batch_size, seq_length]\n            # - if the model is a decoder, apply a causal mask in addition to the padding mask\n            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]\n            if is_decoder:\n                batch_size, seq_length = input_shape\n                seq_ids = torch.arange(seq_length, device=device)\n                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]\n                # in case past_key_values are used we need to add a prefix ones mask to the causal mask\n                # causal and attention masks must have same type with pytorch version < 1.3\n                causal_mask = causal_mask.to(attention_mask.dtype)\n\n                if causal_mask.shape[1] < attention_mask.shape[1]:\n                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]\n                    causal_mask = torch.cat(\n                        [\n                            torch.ones(\n                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype\n                            ),\n                            causal_mask,\n                        ],\n                        axis=-1,\n                    )\n\n                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]\n            else:\n                extended_attention_mask = attention_mask[:, None, None, :]\n        else:\n            raise ValueError(\n                \"Wrong shape for input_ids (shape {}) or attention_mask (shape {})\".format(\n                    input_shape, attention_mask.shape\n                )\n            )\n\n        # Since 
attention_mask is 1.0 for positions we want to attend and 0.0 for\n        # masked positions, this operation will create a tensor which is 0.0 for\n        # positions we want to attend and -10000.0 for masked positions.\n        # Since we are adding it to the raw scores before the softmax, this is\n        # effectively the same as removing these entirely.\n        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility\n        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n        return extended_attention_mask\n    \n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=False,\n        mode='multi_modal',\n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        \"\"\"\n        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n        output_hidden_states = (\n            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n        )\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        if is_decoder:\n            use_cache = use_cache if use_cache is not None else self.config.use_cache\n        else:\n            use_cache = False\n\n        if input_ids is not None and inputs_embeds is not None:\n            raise ValueError(\"You cannot specify both input_ids and inputs_embeds at the same time\")\n        elif input_ids is not None:\n            input_shape = input_ids.size()\n            batch_size, seq_length = input_shape\n            device = input_ids.device\n        elif inputs_embeds is not None:\n            input_shape = inputs_embeds.size()[:-1]\n            batch_size, seq_length = input_shape\n            device = inputs_embeds.device\n        elif encoder_embeds is not None:    \n            input_shape = encoder_embeds.size()[:-1]\n            batch_size, seq_length = input_shape \n            device = encoder_embeds.device\n        else:\n            raise ValueError(\"You have to specify either input_ids or inputs_embeds or encoder_embeds\")\n\n        # past_key_values_length\n        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0\n\n        if attention_mask is None:\n            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)\n        if token_type_ids is None:\n            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)\n\n        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]\n        # ourselves in which case we just need to make it broadcastable to all heads.\n        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, \n                                                                                 device, is_decoder)\n\n        # If a 2D or 3D attention mask is provided for the cross-attention\n        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]\n        if encoder_hidden_states is not None:\n            if type(encoder_hidden_states) == list:\n                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()\n            else:\n                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()\n            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)\n            \n            if type(encoder_attention_mask) == list:\n                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]\n            
elif encoder_attention_mask is None:\n                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)\n                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)\n            else:    \n                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)\n        else:\n            encoder_extended_attention_mask = None\n\n        # Prepare head mask if needed\n        # 1.0 in head_mask indicate we keep the head\n        # attention_probs has shape bsz x n_heads x N x N\n        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]\n        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]\n        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)\n        \n        if encoder_embeds is None:\n            embedding_output = self.embeddings(\n                input_ids=input_ids,\n                position_ids=position_ids,\n                token_type_ids=token_type_ids,\n                inputs_embeds=inputs_embeds,\n                past_key_values_length=past_key_values_length,\n            )\n        else:\n            embedding_output = encoder_embeds\n            \n        encoder_outputs = self.encoder(\n            embedding_output,\n            attention_mask=extended_attention_mask,\n            head_mask=head_mask,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_extended_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            mode=mode,\n        )\n        sequence_output = encoder_outputs[0]\n        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None\n\n        if not return_dict:\n            return (sequence_output, pooled_output) + encoder_outputs[1:]\n\n        return BaseModelOutputWithPoolingAndCrossAttentions(\n            last_hidden_state=sequence_output,\n            pooler_output=pooled_output,\n            past_key_values=encoder_outputs.past_key_values,\n            hidden_states=encoder_outputs.hidden_states,\n            attentions=encoder_outputs.attentions,\n            cross_attentions=encoder_outputs.cross_attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next\n    sentence prediction (classification)` head.\n    \"\"\",\n    BERT_START_DOCSTRING,\n)\nclass BertForPreTraining(BertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config)\n        self.cls = BertPreTrainingHeads(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        
inputs_embeds=None,\n        labels=None,\n        next_sentence_label=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):\n            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``\n        next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):\n            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair\n            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:\n            - 0 indicates sequence B is a continuation of sequence A,\n            - 1 indicates sequence B is a random sequence.\n        kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):\n            Used to hide legacy arguments that have been deprecated.\n        Returns:\n        Example::\n            >>> from transformers import BertTokenizer, BertForPreTraining\n            >>> import torch\n            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n            >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')\n            >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n            >>> outputs = model(**inputs)\n            >>> prediction_logits = outputs.prediction_logits\n            >>> seq_relationship_logits = outputs.seq_relationship_logits\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output, pooled_output = outputs[:2]\n        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)\n\n        total_loss = None\n        if labels is not None and next_sentence_label is not None:\n            loss_fct = CrossEntropyLoss()\n            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))\n            total_loss = masked_lm_loss + next_sentence_loss\n\n        if not return_dict:\n            output = (prediction_scores, seq_relationship_score) + outputs[2:]\n            return ((total_loss,) + output) if total_loss is not None else output\n\n        return BertForPreTrainingOutput(\n            loss=total_loss,\n            prediction_logits=prediction_scores,\n            seq_relationship_logits=seq_relationship_score,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"Bert Model with a `language modeling` head on top for CLM fine-tuning. 
\"\"\", BERT_START_DOCSTRING\n)\nclass BertLMHeadModel(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        past_key_values=None,\n        use_cache=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=True,\n        reduction='mean',\n        mode='multi_modal',\n        soft_labels=None,\n        alpha=0,\n        return_logits=False,        \n    ):\n        r\"\"\"\n        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):\n            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if\n            the model is configured as a decoder.\n        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in\n            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:\n            - 1 for tokens that are **not masked**,\n            - 0 for tokens that are **masked**.\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in\n            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are\n            ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``\n        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):\n            Contains precomputed key and value hidden states of the attention blocks. 
Can be used to speed up decoding.\n            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`\n            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`\n            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.\n        use_cache (:obj:`bool`, `optional`):\n            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up\n            decoding (see :obj:`past_key_values`).\n        Returns:\n        Example::\n            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig\n            >>> import torch\n            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n            >>> config = BertConfig.from_pretrained(\"bert-base-cased\")\n            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)\n            >>> inputs = tokenizer(\"Hello, my dog is cute\", return_tensors=\"pt\")\n            >>> outputs = model(**inputs)\n            >>> prediction_logits = outputs.logits\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        if labels is not None:\n            use_cache = False\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            past_key_values=past_key_values,\n            use_cache=use_cache,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            is_decoder=is_decoder,\n            mode=mode,\n        )\n        \n        sequence_output = outputs[0]\n        prediction_scores = self.cls(sequence_output)\n        \n        if return_logits:\n            return prediction_scores[:, :-1, :].contiguous()  \n\n        lm_loss = None\n        if labels is not None:\n            # we are doing next-token prediction; shift prediction scores and input ids by one\n            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()\n            labels = labels[:, 1:].contiguous()\n            loss_fct = CrossEntropyLoss(reduction=reduction)\n            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n            lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)\n            \n        if soft_labels is not None:\n            loss_distill = -torch.sum(F.log_softmax(shifted_prediction_scores, dim=1)*soft_labels,dim=-1)\n            loss_distill = (loss_distill * (labels!=-100)).sum(1)\n            lm_loss = (1-alpha)*lm_loss + alpha*loss_distill                    \n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return ((lm_loss,) + output) if lm_loss is not None else output\n\n        return CausalLMOutputWithCrossAttentions(\n            loss=lm_loss,\n            logits=prediction_scores,\n            past_key_values=outputs.past_key_values,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            cross_attentions=outputs.cross_attentions,\n        )\n\n    
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):\n        input_shape = input_ids.shape\n        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly\n        if attention_mask is None:\n            attention_mask = input_ids.new_ones(input_shape)\n\n        # cut decoder_input_ids if past is used\n        if past is not None:\n            input_ids = input_ids[:, -1:]\n\n        return {\n            \"input_ids\": input_ids, \n            \"attention_mask\": attention_mask, \n            \"past_key_values\": past,\n            \"encoder_hidden_states\": model_kwargs.get(\"encoder_hidden_states\", None),\n            \"encoder_attention_mask\": model_kwargs.get(\"encoder_attention_mask\", None),\n            \"is_decoder\": True,\n        }\n\n    def _reorder_cache(self, past, beam_idx):\n        reordered_past = ()\n        for layer_past in past:\n            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)\n        return reordered_past\n\n\n@add_start_docstrings(\"\"\"Bert Model with a `language modeling` head on top. \"\"\", BERT_START_DOCSTRING)\nclass BertForMaskedLM(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.cls = BertOnlyMLMHead(config)\n\n        self.init_weights()\n\n    def get_output_embeddings(self):\n        return self.cls.predictions.decoder\n\n    def set_output_embeddings(self, new_embeddings):\n        self.cls.predictions.decoder = new_embeddings\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=\"bert-base-uncased\",\n        output_type=MaskedLMOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        encoder_embeds=None,\n        encoder_hidden_states=None,\n        encoder_attention_mask=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        is_decoder=False,\n        mode='multi_modal',\n        soft_labels=None,\n        alpha=0,\n        return_logits=False,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ...,\n            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored\n            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``\n        \"\"\"\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_embeds=encoder_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            is_decoder=is_decoder,\n            mode=mode,\n        )\n\n        sequence_output = outputs[0]\n        prediction_scores = self.cls(sequence_output)\n        \n        if return_logits:\n            return prediction_scores\n\n        masked_lm_loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()  # -100 index = padding token\n            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))\n        \n        if soft_labels is not None:\n            loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=1)*soft_labels,dim=-1)\n            loss_distill = loss_distill[labels!=-100].mean()\n            masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill\n\n        if not return_dict:\n            output = (prediction_scores,) + outputs[2:]\n            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output\n\n        return MaskedLMOutput(\n            loss=masked_lm_loss,\n            logits=prediction_scores,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):\n        input_shape = input_ids.shape\n        effective_batch_size = input_shape[0]\n\n        #  add a dummy token\n        assert self.config.pad_token_id is not None, \"The PAD token should be defined for generation\"\n        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)\n        dummy_token = torch.full(\n            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device\n        )\n        input_ids = torch.cat([input_ids, dummy_token], dim=1)\n\n        return {\"input_ids\": input_ids, \"attention_mask\": attention_mask}\n\n\n@add_start_docstrings(\n    \"\"\"Bert Model with a `next sentence prediction (classification)` head on top. 
\"\"\",\n    BERT_START_DOCSTRING,\n)\nclass BertForNextSentencePrediction(BertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config)\n        self.cls = BertOnlyNSPHead(config)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n        **kwargs\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair\n            (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:\n            - 0 indicates sequence B is a continuation of sequence A,\n            - 1 indicates sequence B is a random sequence.\n        Returns:\n        Example::\n            >>> from transformers import BertTokenizer, BertForNextSentencePrediction\n            >>> import torch\n            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n            >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')\n            >>> prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n            >>> next_sentence = \"The sky is blue due to the shorter wavelength of blue light.\"\n            >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')\n            >>> outputs = model(**encoding, labels=torch.LongTensor([1]))\n            >>> logits = outputs.logits\n            >>> assert logits[0, 0] < logits[0, 1] # next sentence was random\n        \"\"\"\n\n        if \"next_sentence_label\" in kwargs:\n            warnings.warn(\n                \"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.\",\n                FutureWarning,\n            )\n            labels = kwargs.pop(\"next_sentence_label\")\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        pooled_output = outputs[1]\n\n        seq_relationship_scores = self.cls(pooled_output)\n\n        next_sentence_loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))\n\n        if not return_dict:\n            output = (seq_relationship_scores,) + outputs[2:]\n            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output\n\n        return NextSentencePredictorOutput(\n            loss=next_sentence_loss,\n            
logits=seq_relationship_scores,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled\n    output) e.g. for GLUE tasks.\n    \"\"\",\n    BERT_START_DOCSTRING,\n)\nclass BertForSequenceClassification(BertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.bert = BertModel(config)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=\"bert-base-uncased\",\n        output_type=SequenceClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,\n            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),\n            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        pooled_output = outputs[1]\n\n        pooled_output = self.dropout(pooled_output)\n        logits = self.classifier(pooled_output)\n\n        loss = None\n        if labels is not None:\n            if self.num_labels == 1:\n                #  We are doing regression\n                loss_fct = MSELoss()\n                loss = loss_fct(logits.view(-1), labels.view(-1))\n            else:\n                loss_fct = CrossEntropyLoss()\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return SequenceClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a\n    softmax) e.g. 
for RocStories/SWAG tasks.\n    \"\"\",\n    BERT_START_DOCSTRING,\n)\nclass BertForMultipleChoice(BertPreTrainedModel):\n    def __init__(self, config):\n        super().__init__(config)\n\n        self.bert = BertModel(config)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        self.classifier = nn.Linear(config.hidden_size, 1)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, num_choices, sequence_length\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=\"bert-base-uncased\",\n        output_type=MultipleChoiceModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,\n            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See\n            :obj:`input_ids` above)\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]\n\n        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None\n        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None\n        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None\n        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None\n        inputs_embeds = (\n            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))\n            if inputs_embeds is not None\n            else None\n        )\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        pooled_output = outputs[1]\n\n        pooled_output = self.dropout(pooled_output)\n        logits = self.classifier(pooled_output)\n        reshaped_logits = logits.view(-1, num_choices)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            loss = loss_fct(reshaped_logits, labels)\n\n        if not return_dict:\n            output = (reshaped_logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return MultipleChoiceModelOutput(\n            loss=loss,\n            logits=reshaped_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. 
for\n    Named-Entity-Recognition (NER) tasks.\n    \"\"\",\n    BERT_START_DOCSTRING,\n)\nclass BertForTokenClassification(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n        self.classifier = nn.Linear(config.hidden_size, config.num_labels)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=\"bert-base-uncased\",\n        output_type=TokenClassifierOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        labels=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -\n            1]``.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        sequence_output = self.dropout(sequence_output)\n        logits = self.classifier(sequence_output)\n\n        loss = None\n        if labels is not None:\n            loss_fct = CrossEntropyLoss()\n            # Only keep active parts of the loss\n            if attention_mask is not None:\n                active_loss = attention_mask.view(-1) == 1\n                active_logits = logits.view(-1, self.num_labels)\n                active_labels = torch.where(\n                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)\n                )\n                loss = loss_fct(active_logits, active_labels)\n            else:\n                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n\n        if not return_dict:\n            output = (logits,) + outputs[2:]\n            return ((loss,) + output) if loss is not None else output\n\n        return TokenClassifierOutput(\n            loss=loss,\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\n@add_start_docstrings(\n    \"\"\"\n    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear\n    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).\n    \"\"\",\n    BERT_START_DOCSTRING,\n)\nclass BertForQuestionAnswering(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n\n    def 
__init__(self, config):\n        super().__init__(config)\n        self.num_labels = config.num_labels\n\n        self.bert = BertModel(config, add_pooling_layer=False)\n        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)\n\n        self.init_weights()\n\n    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format(\"batch_size, sequence_length\"))\n    @add_code_sample_docstrings(\n        tokenizer_class=_TOKENIZER_FOR_DOC,\n        checkpoint=\"bert-base-uncased\",\n        output_type=QuestionAnsweringModelOutput,\n        config_class=_CONFIG_FOR_DOC,\n    )\n    def forward(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        token_type_ids=None,\n        position_ids=None,\n        head_mask=None,\n        inputs_embeds=None,\n        start_positions=None,\n        end_positions=None,\n        output_attentions=None,\n        output_hidden_states=None,\n        return_dict=None,\n    ):\n        r\"\"\"\n        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the start of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n            sequence are not taken into account for computing the loss.\n        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):\n            Labels for position (index) of the end of the labelled span for computing the token classification loss.\n            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the\n            sequence are not taken into account for computing the loss.\n        \"\"\"\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n        outputs = self.bert(\n            input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        sequence_output = outputs[0]\n\n        logits = self.qa_outputs(sequence_output)\n        start_logits, end_logits = logits.split(1, dim=-1)\n        start_logits = start_logits.squeeze(-1)\n        end_logits = end_logits.squeeze(-1)\n\n        total_loss = None\n        if start_positions is not None and end_positions is not None:\n            # If we are on multi-GPU, split add a dimension\n            if len(start_positions.size()) > 1:\n                start_positions = start_positions.squeeze(-1)\n            if len(end_positions.size()) > 1:\n                end_positions = end_positions.squeeze(-1)\n            # sometimes the start/end positions are outside our model inputs, we ignore these terms\n            ignored_index = start_logits.size(1)\n            start_positions.clamp_(0, ignored_index)\n            end_positions.clamp_(0, ignored_index)\n\n            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)\n            start_loss = loss_fct(start_logits, start_positions)\n            end_loss = loss_fct(end_logits, end_positions)\n            total_loss = (start_loss + end_loss) / 2\n\n        if not return_dict:\n            output = (start_logits, end_logits) + outputs[2:]\n            
return ((total_loss,) + output) if total_loss is not None else output\n\n        return QuestionAnsweringModelOutput(\n            loss=total_loss,\n            start_logits=start_logits,\n            end_logits=end_logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "src/optimization/adamw.py",
    "content": "\"\"\"\nAdamW optimizer (weight decay fix)\ncopied from hugginface\n\"\"\"\nimport math\n\nimport torch\nfrom torch.optim import Optimizer\n\n\nclass AdamW(Optimizer):\n    \"\"\" Implements Adam algorithm with weight decay fix.\n    Parameters:\n        lr (float): learning rate. Default 1e-3.\n        betas (tuple of 2 floats): Adams beta parameters (b1, b2).\n            Default: (0.9, 0.999)\n        eps (float): Adams epsilon. Default: 1e-6\n        weight_decay (float): Weight decay. Default: 0.0\n        correct_bias (bool): can be set to False to avoid correcting bias\n            in Adam (e.g. like in Bert TF repository). Default True.\n    \"\"\"\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,\n                 weight_decay=0.0, correct_bias=True):\n        if lr < 0.0:\n            raise ValueError(\n                \"Invalid learning rate: {} - should be >= 0.0\".format(lr))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter: {} - \"\n                             \"should be in [0.0, 1.0[\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter: {} - \"\n                             \"should be in [0.0, 1.0[\".format(betas[1]))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {} - \"\n                             \"should be >= 0.0\".format(eps))\n        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,\n                        correct_bias=correct_bias)\n        super(AdamW, self).__init__(params, defaults)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError(\n                        'Adam does not support sparse '\n                        'gradients, please consider SparseAdam instead')\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                # Decay the first and second moment running average coefficient\n                # In-place operations to update the averages at the same time\n                exp_avg.mul_(beta1).add_(1.0 - beta1, grad)\n                exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)\n                denom = exp_avg_sq.sqrt().add_(group['eps'])\n\n                step_size = group['lr']\n                if group['correct_bias']:  # No bias correction for Bert\n                    bias_correction1 = 1.0 - beta1 ** state['step']\n                    bias_correction2 = 1.0 - 
beta2 ** state['step']\n                    step_size = (step_size * math.sqrt(bias_correction2)\n                                 / bias_correction1)\n\n                p.data.addcdiv_(-step_size, exp_avg, denom)\n\n                # Just adding the square of the weights to the loss function is\n                # *not* the correct way of using L2 regularization/weight decay\n                # with Adam, since that will interact with the m and v\n                # parameters in strange ways.\n                #\n                # Instead we want to decay the weights in a manner that doesn't\n                # interact with the m/v parameters. This is equivalent to\n                # adding the square of the weights to the loss with plain\n                # (non-momentum) SGD.\n                # Add weight decay at the end (fixed version)\n                if group['weight_decay'] > 0.0:\n                    p.data.add_(-group['lr'] * group['weight_decay'], p.data)\n\n        return loss\n"
  },
  {
    "path": "src/optimization/sched.py",
    "content": "\"\"\"\noptimizer learning rate scheduling helpers\n\"\"\"\nfrom math import ceil\nfrom collections import Counter\n\n\ndef noam_schedule(step, warmup_step=4000):\n    if step <= warmup_step:\n        return step / warmup_step\n    return (warmup_step ** 0.5) * (step ** -0.5)\n\n\ndef warmup_linear(step, warmup_step, tot_step):\n    if step < warmup_step:\n        return step / warmup_step\n    return max(0, (tot_step-step)/(tot_step-warmup_step))\n\n\ndef multi_step_schedule(n_epoch, milestones, gamma=0.5):\n    milestones = list(sorted(milestones))\n    for i, m in enumerate(milestones):\n        if n_epoch < m:\n            return gamma**i\n    return gamma**(len(milestones)+1)\n\n\ndef get_lr_sched(global_step, decay, learning_rate,\n                 num_train_steps, warmup_ratio=0.1,\n                 decay_epochs=[], multi_step_epoch=-1):\n    warmup_steps = int(warmup_ratio*num_train_steps)\n    if decay == 'linear':\n        lr_this_step = learning_rate * warmup_linear(\n            global_step, warmup_steps, num_train_steps)\n    elif decay == 'invsqrt':\n        lr_this_step = learning_rate * noam_schedule(\n            global_step, warmup_steps)\n    elif decay == 'constant':\n        lr_this_step = learning_rate\n    elif decay == \"multi_step\":\n        assert multi_step_epoch >= 0\n        lr_this_step = learning_rate * multi_step_schedule(\n            multi_step_epoch, decay_epochs)\n    if lr_this_step <= 0:\n        # save guard for possible miscalculation of train steps\n        lr_this_step = 1e-8\n    return lr_this_step\n"
  },
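  {
    "path": "src/optimization/sched_example.py",
    "content": "\"\"\"\nIllustrative usage sketch (editor addition, not part of the original ALPRO sources).\n\nPrints the per-step learning rates produced by `get_lr_sched` in\nsrc/optimization/sched.py for two of the decay schedules it supports\n('linear' and 'invsqrt'): warmup for `warmup_ratio * num_train_steps` steps,\nthen decay. The base lr and step counts below are arbitrary toy values.\n\"\"\"\nfrom src.optimization.sched import get_lr_sched\n\n\ndef main():\n    learning_rate = 1e-4   # toy base lr\n    num_train_steps = 100  # toy total number of optimizer steps\n    for decay in ['linear', 'invsqrt']:\n        print(f'--- decay = {decay} ---')\n        for global_step in range(0, num_train_steps + 1, 10):\n            lr = get_lr_sched(global_step, decay=decay,\n                              learning_rate=learning_rate,\n                              num_train_steps=num_train_steps,\n                              warmup_ratio=0.1)\n            print(f'step {global_step:3d} -> lr {lr:.2e}')\n\n\nif __name__ == '__main__':\n    main()\n"
  },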
  {
    "path": "src/optimization/utils.py",
    "content": "from torch.optim import Adam, Adamax, SGD\nfrom src.optimization.adamw import AdamW\n\n\ndef setup_e2e_optimizer(model, opts):\n    if opts.optim == 'adam':\n        OptimCls = Adam\n    elif opts.optim == 'adamax':\n        OptimCls = Adamax\n    elif opts.optim == 'adamw':\n        OptimCls = AdamW\n    else:\n        raise ValueError('invalid optimizer')\n    optimizer = OptimCls(model.parameters(), lr=opts.learning_rate, betas=opts.betas)\n\n    return optimizer\n"
  },
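  {
    "path": "src/optimization/optimizer_example.py",
    "content": "\"\"\"\nIllustrative usage sketch (editor addition, not part of the original ALPRO sources).\n\nShows how `setup_e2e_optimizer` and `get_lr_sched` are wired together, mirroring\nthe training loops in src/pretrain/ and src/tasks/: build the optimizer from a\nconfig exposing `optim`, `learning_rate` and `betas`, then overwrite the lr of\nevery param group at each step. The tiny linear model and the `opts` namespace\nbelow are mock objects for illustration only.\n\"\"\"\nfrom types import SimpleNamespace\n\nimport torch\n\nfrom src.optimization.sched import get_lr_sched\nfrom src.optimization.utils import setup_e2e_optimizer\n\n\ndef main():\n    # mock config with only the fields setup_e2e_optimizer reads\n    opts = SimpleNamespace(optim='adamw', learning_rate=1e-4, betas=(0.9, 0.98))\n    model = torch.nn.Linear(4, 2)  # stand-in for the real video-language model\n    optimizer = setup_e2e_optimizer(model, opts)\n\n    num_train_steps = 50\n    for global_step in range(1, num_train_steps + 1):\n        loss = model(torch.randn(8, 4)).pow(2).mean()\n        loss.backward()\n\n        # per-step lr scheduling, as done in the run_* training scripts\n        lr_this_step = get_lr_sched(global_step, decay='linear',\n                                    learning_rate=opts.learning_rate,\n                                    num_train_steps=num_train_steps,\n                                    warmup_ratio=0.1)\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = lr_this_step\n\n        optimizer.step()\n        optimizer.zero_grad()\n\n\nif __name__ == '__main__':\n    main()\n"
  },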
  {
    "path": "src/pretrain/run_pretrain_contrastive_only.py",
    "content": "import os\n\nimport torch\nimport time\nimport random\nimport pprint\nimport math\nimport json\nfrom transformers import BertConfig, BertTokenizerFast\n\nfrom src.datasets.dataset_pretrain_sparse import AlproPretrainSparseDataset, PretrainImageTextDataset, PretrainCollator\nfrom src.datasets.dataloader import MetaLoader, PrefetchLoader\nfrom src.datasets.data_utils import ImageNorm, mk_input_group\nfrom torch.utils.data import DataLoader\nfrom torch.nn.utils import clip_grad_norm_\nfrom src.configs.config import shared_configs\nfrom src.utils.misc import set_random_seed, NoOp, zero_none_grad\nfrom src.utils.logger import LOGGER, TB_LOGGER, add_log_to_file, RunningMeter\nfrom src.utils.basic_utils import load_jsonl, load_json, read_dataframe\nfrom src.utils.load_save import (ModelSaver,\n                                 save_training_meta,\n                                 load_state_dict_with_pos_embed_resizing)\nfrom src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer\nfrom src.optimization.sched import get_lr_sched\nfrom src.optimization.utils import setup_e2e_optimizer\nfrom collections import defaultdict\nfrom tqdm import tqdm\nfrom os.path import join\nfrom apex import amp\nfrom torch.utils.data.distributed import DistributedSampler\nimport horovod.torch as hvd\nfrom src.utils.distributed import all_gather_list\n\nfrom src.modeling.alpro_models import Prompter\n\n\ndef mk_captions_pretrain_dataloader(dataset_name, anno_path, video_dir, txt_dir, cfg, tokenizer, \n                                    is_train=True, max_txt_len=80):\n    # make a list(dict), where each dict {vis_id: int, txt: str}\n    if dataset_name == \"webvid2m\":\n        datalist = read_dataframe(anno_path)\n\n        datalist = datalist[datalist['txt_len'] < max_txt_len]\n        LOGGER.info('Found {} entries for webvid2m'.format(len(datalist)))\n    \n    elif dataset_name == \"cc3m\":\n        datalist = json.load(open(anno_path))\n        LOGGER.info('Found {} entries for cc3m'.format(len(datalist)))\n\n    else:\n        raise ValueError(\"Invalid dataset_name\")\n\n    if dataset_name in [\"webvid2m\"]:\n        frm_sampling_strategy = cfg.frm_sampling_strategy\n        if not is_train and frm_sampling_strategy == \"rand\":\n            frm_sampling_strategy = \"uniform\"\n        dataset = AlproPretrainSparseDataset(\n            datalist=datalist,\n            tokenizer=tokenizer,\n            img_lmdb_dir=video_dir,\n            img_db_type='rawvideo',\n            txt_dir=txt_dir,\n            crop_size=cfg.crop_img_size,\n            resize_size=cfg.resize_size,\n            max_txt_len=cfg.max_txt_len,\n            use_itm=cfg.use_itm,\n            fps=cfg.fps,\n            num_frm=cfg.num_frm,\n            frm_sampling_strategy=frm_sampling_strategy,\n            is_train=is_train\n            # vis_format=vis_format\n        )\n    elif dataset_name in [\"cc3m\"]:\n        dataset = PretrainImageTextDataset(datalist=datalist, \n                                           tokenizer=tokenizer,\n                                           crop_size=cfg.crop_img_size,\n                                           resize_size=cfg.resize_size,\n                                           max_txt_len=cfg.max_txt_len,\n                                           num_frm=cfg.num_frm\n                                           )\n\n    LOGGER.info(f\"[{dataset_name}] is_train {is_train} \"\n                f\"dataset size {len(dataset)}, \")\n    batch_size = cfg.train_batch_size if 
is_train else cfg.val_batch_size\n    sampler = DistributedSampler(\n        dataset, num_replicas=hvd.size(), rank=hvd.rank(),\n        shuffle=is_train)\n    data_collator = PretrainCollator(tokenizer=tokenizer,\n                                    mlm=cfg.use_mlm,\n                                    mlm_probability=0.15,\n                                    max_length=cfg.max_txt_len,\n                                    is_train=is_train)\n    dataloader = DataLoader(dataset,\n                            batch_size=batch_size,\n                            shuffle=False,\n                            sampler=sampler,\n                            num_workers=cfg.n_workers,\n                            pin_memory=cfg.pin_mem,\n                            collate_fn=data_collator.collate_batch)\n\n    return dataloader\n\n\ndef setup_dataloaders(cfg, tokenizer):\n    LOGGER.info(\"Init. train_loader and val_loader...\")\n\n    train_loaders = {}\n    for db in cfg.train_datasets:\n        train_loaders[db.name] = mk_captions_pretrain_dataloader(\n            dataset_name=db.name,\n            anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,\n            cfg=cfg, tokenizer=tokenizer, is_train=True\n        )\n\n    val_loaders = {}\n    for db in cfg.val_datasets:\n        val_loaders[db.name] = mk_captions_pretrain_dataloader(\n            dataset_name=db.name,\n            anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,\n            cfg=cfg, tokenizer=tokenizer, is_train=False\n        )\n    return train_loaders, val_loaders\n\n\ndef setup_model(cfg, device=None):\n    LOGGER.info(\"Setup model...\")\n    # has to be a BertConfig instance\n    model_cfg = load_json(cfg.model_config)\n    model_cfg = BertConfig(**model_cfg)\n    # add model-specific config\n    add_attr_list = [\n        \"max_n_example_per_group\",\n        \"num_entities\"\n    ]\n    for k in add_attr_list:\n        setattr(model_cfg, k, cfg[k])\n    LOGGER.info(f\"model_cfg {pprint.pformat(model_cfg.to_dict())}\")\n\n    LOGGER.info(\"setup e2e model\")\n\n    if cfg.model_type == 'pretrain':\n        # initialize cnn config\n        video_enc_cfg = load_json(cfg.visual_model_cfg)\n\n        video_enc_cfg['num_frm'] = cfg.num_frm\n        video_enc_cfg['img_size'] = cfg.crop_img_size\n\n        model = Prompter(\n            model_cfg, \n            input_format=cfg.img_input_format,\n            video_enc_cfg=video_enc_cfg\n            )\n        if cfg.e2e_weights_path:\n            LOGGER.info(f\"Loading e2e weights from {cfg.e2e_weights_path}\")\n            num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2\n            # NOTE strict if False if loaded from ALBEF ckpt\n            load_state_dict_with_pos_embed_resizing(model, \n                                                    cfg.e2e_weights_path, \n                                                    num_patches=num_patches, \n                                                    num_frames=cfg.num_frm, \n                                                    strict=not cfg.albef_init\n                                                    )\n            \n        else:\n            LOGGER.info(f\"Loading visual weights from {cfg.visual_weights_path}\")\n            LOGGER.info(f\"Loading bert weights from {cfg.bert_weights_path}\")\n            model.load_separate_ckpt(\n                visual_weights_path=cfg.visual_weights_path,\n                bert_weights_path=cfg.bert_weights_path\n            )\n    else:\n        raise 
NotImplementedError(f\"cfg.model_type not found {cfg.model_type}.\")\n\n    # if cfg.freeze_cnn:\n    #     model.freeze_cnn_backbone()\n    \n    LOGGER.info(\"Moving model to device\") \n    model.to(device)\n    LOGGER.info(\"Completed moving model to device.\") \n\n    LOGGER.info(\"Setup model done!\")\n    return model\n\n\ndef forward_step(cfg, model, batch):\n    \"\"\"shared for training and validation\"\"\"\n    # used to make visual feature copies\n    if not cfg.use_itm:\n        batch[\"itm_labels\"] = None\n    outputs = model(batch)  # dict\n    return outputs\n\n\n@torch.no_grad()\ndef validate(model, val_loader, cfg):\n    model.eval()\n\n    n_itc_ex = 0\n    n_t2i_corrects = 0\n    n_i2t_corrects = 0\n\n    itc_loss = 0\n    st = time.time()\n    val_log = {'valid/itc_loss': 0,\n               'valid/i2t_acc': 0,\n               'valid/t2i_acc': 0\n               }\n\n    debug_step = 5\n    val_loaders = val_loader if isinstance(val_loader, dict) else {\n        \"unnamed_val_loader\": val_loader}\n    \n    total_val_iters = 0 \n\n    LOGGER.info(f\"In total {len(val_loaders)} val loaders\")\n    for loader_name, val_loader in val_loaders.items():\n        LOGGER.info(f\"Loop val_loader {loader_name}.\")\n\n        total_val_iters += len(val_loader)\n        for val_step, batch in enumerate(val_loader):\n            # use iter to reset MetaLoader\n            # forward pass\n            outputs = forward_step(cfg, model, batch)\n\n            assert not cfg.use_itm and not cfg.use_mlm\n\n            if cfg.use_itc:\n                itc_loss += outputs[\"itc_loss\"].sum().item()\n\n            if cfg.debug and val_step >= debug_step:\n                break\n\n    # Gather across all processes\n    all_gather_itc_loss = all_gather_list(itc_loss)\n    itc_loss = sum(all_gather_itc_loss)\n\n    # FIXME check this whether take mean?\n    assert cfg.use_itc, 'cfg.use_itc is False for contrastive-only pretraining.'\n\n    val_log.update({\n        'valid/itc_loss': float(itc_loss),\n    })\n\n    n_itc_ex += len(outputs[\"itc_labels\"])\n    n_t2i_corrects += (\n            outputs[\"t2i_scores\"].max(\n                dim=-1)[1] == outputs[\"itc_labels\"]).sum().item()\n    n_i2t_corrects += (\n            outputs[\"i2t_scores\"].max(\n                dim=-1)[1] == outputs[\"itc_labels\"]).sum().item()\n\n    n_i2t_corrects = sum(all_gather_list(n_i2t_corrects))\n    n_t2i_corrects = sum(all_gather_list(n_t2i_corrects))\n\n    n_itc_ex = sum(all_gather_list(n_itc_ex))\n\n    if n_itc_ex != 0:\n        val_log.update({\n            'valid/i2t_acc': float(n_i2t_corrects / n_itc_ex),\n            'valid/t2i_acc': float(n_t2i_corrects / n_itc_ex)\n        })\n\n    TB_LOGGER.log_scalar_dict(val_log)\n    LOGGER.info(f\"validation finished in {int(time.time() - st)} seconds, \")\n\n    LOGGER.info(\"[itc_loss]: {} \".format(itc_loss))\n    LOGGER.info(\"In total, {} validation iters.\".format(total_val_iters))\n\n    model.train()\n    return val_log\n\n\ndef start_training():\n    cfg = shared_configs.get_sparse_pretraining_args()\n    set_random_seed(cfg.seed)\n\n    n_gpu = hvd.size()\n    # device = torch.device(\"cuda\", hvd.local_rank())\n    # torch.cuda.set_device(hvd.local_rank())\n\n    # This resolves the issue GPU 0 always has more processes running and more GPU-RAM.\n    # c.f. 
https://github.com/horovod/horovod/issues/2625#issuecomment-868134876\n    os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())\n    device = torch.device(\"cuda\", 0)\n    torch.cuda.set_device(0)\n\n    if hvd.rank() != 0:\n        LOGGER.disabled = True\n    LOGGER.info(f\"device: {device} n_gpu: {n_gpu}, \"\n                f\"rank: {hvd.rank()}, 16-bits training: {cfg.fp16}\")\n\n    model = setup_model(cfg, device=device)\n    model.train()\n\n    optimizer = setup_e2e_optimizer(model, cfg)\n\n    # Horovod: (optional) compression algorithm.compressin\n    compression = hvd.Compression.none\n    optimizer = hvd.DistributedOptimizer(\n        optimizer, named_parameters=model.named_parameters(),\n        compression=compression)\n\n    #  Horovod: broadcast parameters & optimizer state.\n    compression = hvd.Compression.none\n    hvd.broadcast_parameters(model.state_dict(), root_rank=0)\n    hvd.broadcast_optimizer_state(optimizer, root_rank=0)\n\n    model, optimizer = amp.initialize(\n        model, optimizer, enabled=cfg.fp16, opt_level='O2',\n        keep_batchnorm_fp32=True)\n\n    # prepare data\n    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)\n    train_loaders, val_loaders = setup_dataloaders(cfg, tokenizer)\n    train_loader = MetaLoader(train_loaders,\n                              accum_steps=cfg.gradient_accumulation_steps,\n                              distributed=n_gpu > 1)\n    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)\n    train_loader = PrefetchLoader(train_loader, img_norm)\n    val_loaders = {k: PrefetchLoader(v, img_norm)\n                   for k, v in val_loaders.items()}\n\n    # compute the number of steps and update cfg\n    total_train_batch_size = int(\n        n_gpu * cfg.train_batch_size *\n        cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)\n    total_n_epochs = cfg.num_train_epochs\n    cfg.num_train_steps = int(math.ceil(\n        1. * train_loader.n_batches_in_epoch * total_n_epochs /\n        (n_gpu * cfg.gradient_accumulation_steps)))\n    cfg.valid_steps = int(math.ceil(\n        1. * cfg.num_train_steps / cfg.num_valid /\n        cfg.min_valid_steps)) * cfg.min_valid_steps\n    actual_num_valid = int(math.floor(\n        1. 
* cfg.num_train_steps / cfg.valid_steps)) + 1\n\n    # restore\n    restorer = TrainingRestorer(cfg, model, optimizer)\n    global_step = restorer.global_step\n    TB_LOGGER.global_step = global_step\n    if hvd.rank() == 0:\n        LOGGER.info(\"Saving training meta...\")\n        save_training_meta(cfg)\n        LOGGER.info(\"Saving training done...\")\n        TB_LOGGER.create(join(cfg.output_dir, 'log'))\n        pbar = tqdm(total=cfg.num_train_steps)\n        model_saver = ModelSaver(join(cfg.output_dir, \"ckpt\"))\n        add_log_to_file(join(cfg.output_dir, \"log\", \"log.txt\"))\n    else:\n        LOGGER.disabled = True\n        pbar = NoOp()\n        model_saver = NoOp()\n        restorer = NoOp()\n\n    if global_step > 0:\n        pbar.update(global_step)\n\n    LOGGER.info(cfg)\n    LOGGER.info(\"Starting training...\")\n    LOGGER.info(f\"***** Running training with {n_gpu} GPUs *****\")\n    LOGGER.info(f\"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}\")\n    LOGGER.info(f\"  max_n_example_per_group = {cfg.max_n_example_per_group}\")\n    LOGGER.info(f\"  Accumulate steps = {cfg.gradient_accumulation_steps}\")\n    LOGGER.info(f\"  Total batch size = #GPUs * Single-GPU batch size * \"\n                f\"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}\")\n    LOGGER.info(f\"  Total #batches - single epoch = {train_loader.n_batches_in_epoch}.\")\n    LOGGER.info(f\"  Total #steps = {cfg.num_train_steps}\")\n    LOGGER.info(f\"  Total #epochs = {total_n_epochs}.\")\n    LOGGER.info(f\"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times\")\n\n\n    # quick hack for amp delay_unscale bug\n    with optimizer.skip_synchronize():\n        optimizer.zero_grad()\n        if global_step == 0:\n            optimizer.step()\n    debug_step = 5\n\n    tasks = []\n    for name, flag in zip([\"itc\"], [cfg.use_itc]):\n        if flag:\n            tasks.append(name)\n    task2loss = {t: RunningMeter(f'train_loss/{t}')\n                 for t in tasks}\n    task2loss[\"loss\"] = RunningMeter('train_loss/loss')\n\n    train_log = {'train/i2t_acc': 0,\n                 'train/t2i_acc': 0}\n\n    for step, (task, batch) in enumerate(train_loader):\n        # forward pass\n        outputs = forward_step(cfg, model, batch)\n        # mlm_loss, itm_loss, itc_loss, mpm_loss = 0, 0, 0, 0\n        itc_loss = 0\n\n        assert not cfg.use_mlm and not cfg.use_itm\n        \n        if cfg.use_itc:\n            n_itc_ex = len(outputs[\"itc_labels\"])\n            n_t2i_corrects = (\n                    outputs[\"t2i_scores\"].max(\n                        dim=-1)[1] == outputs[\"itc_labels\"]).sum().item()\n            n_i2t_corrects = (\n                    outputs[\"i2t_scores\"].max(\n                        dim=-1)[1] == outputs[\"itc_labels\"]).sum().item()\n\n            train_log.update({\n                'train/t2i_acc': float(n_t2i_corrects / n_itc_ex),\n                'train/i2t_acc': float(n_i2t_corrects / n_itc_ex),\n                # 'train/mpm_acc': mpm_acc\n            })\n\n            itc_loss = outputs[\"itc_loss\"]\n            task2loss[\"itc\"](itc_loss.item())\n\n        loss = itc_loss\n        task2loss[\"loss\"](loss.item())\n\n        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0\n        with amp.scale_loss(\n                loss, optimizer, delay_unscale=delay_unscale\n                ) as scaled_loss:\n            scaled_loss.backward()\n            zero_none_grad(model)\n      
      optimizer.synchronize()\n\n        # optimizer\n        if (step + 1) % cfg.gradient_accumulation_steps == 0:\n            global_step += 1\n            if (step + 1) % cfg.log_interval == 0:\n                TB_LOGGER.log_scalar_dict({l.name: l.val\n                                        for l in task2loss.values()\n                                        if l.val is not None})\n            n_epoch = int(1. * n_gpu * cfg.gradient_accumulation_steps *\n                          global_step / train_loader.n_batches_in_epoch)\n\n            # learning rate scheduling for the whole model\n            lr_this_step = get_lr_sched(\n                global_step, cfg.decay, cfg.learning_rate,\n                cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,\n                decay_epochs=cfg.step_decay_epochs,\n                multi_step_epoch=n_epoch)\n\n            # Hardcoded param group length\n            # assert len(optimizer.param_groups) == 8\n            for pg_n, param_group in enumerate(\n                    optimizer.param_groups):\n                    param_group['lr'] = lr_this_step\n\n            if (step + 1) % cfg.log_interval == 0:\n                TB_LOGGER.add_scalar(\n                    \"train/lr\", lr_this_step, global_step)\n\n            # update model params\n            if cfg.grad_norm != -1:\n                # import pdb; pdb.set_trace()\n                grad_norm = clip_grad_norm_(\n                    amp.master_params(optimizer), cfg.grad_norm)\n                if (step + 1) % cfg.log_interval == 0:\n                    TB_LOGGER.add_scalar(\"train/grad_norm\", grad_norm, global_step)\n            TB_LOGGER.step()\n\n            # Check if there is None grad\n            none_grads = [\n                p[0] for p in model.named_parameters()\n                if p[1].requires_grad and p[1].grad is None]\n\n            assert len(none_grads) == 0, f\"{none_grads}\"\n\n            with optimizer.skip_synchronize():\n                optimizer.step()\n                optimizer.zero_grad()\n            restorer.step()\n            pbar.update(1)\n\n            # validate and checkpoint\n            if global_step % cfg.valid_steps == 0:\n                LOGGER.info(f'Step {global_step}: start validation')\n                validate(model, val_loaders, cfg)\n                model_saver.save(step=global_step, model=model)\n        if global_step >= cfg.num_train_steps:\n            break\n\n        if cfg.debug and global_step >= debug_step:\n            break\n\n    if global_step % cfg.valid_steps != 0:\n        LOGGER.info(f'Step {global_step}: start validation')\n        validate(model, val_loaders, cfg)\n        model_saver.save(step=global_step, model=model)\n\n\nif __name__ == '__main__':\n    # Initialize Horovod\n    hvd.init()\n    start_training()\n"
  },
  {
    "path": "src/pretrain/run_pretrain_sparse.py",
    "content": "import os\n\nimport torch\nimport time\nimport random\nimport pprint\nimport math\nimport json\nfrom transformers import BertConfig, BertTokenizerFast\n\nfrom src.datasets.dataset_pretrain_sparse import AlproPretrainSparseDataset, PretrainImageTextDataset, PretrainCollator\nfrom src.datasets.dataloader import MetaLoader, PrefetchLoader\nfrom src.datasets.data_utils import ImageNorm, mk_input_group\nfrom torch.utils.data import DataLoader\nfrom torch.nn.utils import clip_grad_norm_\nfrom src.configs.config import shared_configs\nfrom src.utils.misc import set_random_seed, NoOp, zero_none_grad\nfrom src.utils.logger import LOGGER, TB_LOGGER, add_log_to_file, RunningMeter\nfrom src.utils.basic_utils import load_jsonl, load_json, read_dataframe\nfrom src.utils.load_save import (ModelSaver,\n                                 save_training_meta,\n                                 load_state_dict_with_pos_embed_resizing)\nfrom src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer\nfrom src.optimization.sched import get_lr_sched\nfrom src.optimization.utils import setup_e2e_optimizer\nfrom collections import defaultdict\nfrom tqdm import tqdm\nfrom os.path import join\nfrom apex import amp\nfrom torch.utils.data.distributed import DistributedSampler\nimport horovod.torch as hvd\nfrom src.utils.distributed import all_gather_list\n\nfrom src.modeling.alpro_models import AlproForPretrain\n\n\ndef mk_captions_pretrain_dataloader(dataset_name, anno_path, video_dir, txt_dir, cfg, tokenizer, \n                                    is_train=True, max_txt_len=80):\n    # make a list(dict), where each dict {vis_id: int, txt: str}\n    if dataset_name == \"webvid2m\":\n        datalist = read_dataframe(anno_path)\n\n        datalist = datalist[datalist['txt_len'] < max_txt_len]\n        LOGGER.info('Found {} entries for webvid2m'.format(len(datalist)))\n    \n    elif dataset_name == \"cc3m\":\n        datalist = json.load(open(anno_path))\n        LOGGER.info('Found {} entries for cc3m'.format(len(datalist)))\n\n    else:\n        raise ValueError(\"Invalid dataset_name\")\n\n    if dataset_name in [\"webvid2m\"]:\n        frm_sampling_strategy = cfg.frm_sampling_strategy\n        if not is_train and frm_sampling_strategy == \"rand\":\n            frm_sampling_strategy = \"uniform\"\n        dataset = AlproPretrainSparseDataset(\n            datalist=datalist,\n            tokenizer=tokenizer,\n            img_lmdb_dir=video_dir,\n            img_db_type='rawvideo',\n            txt_dir=txt_dir,\n            crop_size=cfg.crop_img_size,\n            resize_size=cfg.resize_size,\n            max_txt_len=cfg.max_txt_len,\n            use_itm=cfg.use_itm,\n            fps=cfg.fps,\n            num_frm=cfg.num_frm,\n            frm_sampling_strategy=frm_sampling_strategy,\n            is_train=is_train\n            # vis_format=vis_format\n        )\n    elif dataset_name in [\"cc3m\"]:\n        dataset = PretrainImageTextDataset(datalist=datalist, \n                                           tokenizer=tokenizer,\n                                           crop_size=cfg.crop_img_size,\n                                           resize_size=cfg.resize_size,\n                                           max_txt_len=cfg.max_txt_len,\n                                           num_frm=cfg.num_frm\n                                           )\n\n    LOGGER.info(f\"[{dataset_name}] is_train {is_train} \"\n                f\"dataset size {len(dataset)}, \")\n    batch_size = 
cfg.train_batch_size if is_train else cfg.val_batch_size\n    sampler = DistributedSampler(\n        dataset, num_replicas=hvd.size(), rank=hvd.rank(),\n        shuffle=is_train)\n    data_collator = PretrainCollator(tokenizer=tokenizer,\n                                    mlm=cfg.use_mlm,\n                                    mlm_probability=0.15,\n                                    max_length=cfg.max_txt_len,\n                                    mpm=cfg.use_mpm,\n                                    is_train=is_train)\n    dataloader = DataLoader(dataset,\n                            batch_size=batch_size,\n                            shuffle=False,\n                            sampler=sampler,\n                            num_workers=cfg.n_workers,\n                            pin_memory=cfg.pin_mem,\n                            collate_fn=data_collator.collate_batch)\n\n    return dataloader\n\n\ndef setup_dataloaders(cfg, tokenizer):\n    LOGGER.info(\"Init. train_loader and val_loader...\")\n\n    train_loaders = {}\n    for db in cfg.train_datasets:\n        train_loaders[db.name] = mk_captions_pretrain_dataloader(\n            dataset_name=db.name,\n            anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,\n            cfg=cfg, tokenizer=tokenizer, is_train=True\n        )\n\n    val_loaders = {}\n    for db in cfg.val_datasets:\n        val_loaders[db.name] = mk_captions_pretrain_dataloader(\n            dataset_name=db.name,\n            anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,\n            cfg=cfg, tokenizer=tokenizer, is_train=False\n        )\n    return train_loaders, val_loaders\n\n\ndef setup_model(cfg, device=None):\n    LOGGER.info(\"Setup model...\")\n    # has to be a BertConfig instance\n    model_cfg = load_json(cfg.model_config)\n    model_cfg = BertConfig(**model_cfg)\n    # add model-specific config\n    add_attr_list = [\n        \"max_n_example_per_group\",\n        \"num_entities\"\n    ]\n    for k in add_attr_list:\n        setattr(model_cfg, k, cfg[k])\n    LOGGER.info(f\"model_cfg {pprint.pformat(model_cfg.to_dict())}\")\n\n    LOGGER.info(\"setup e2e model\")\n\n    if cfg.model_type == 'pretrain':\n        # initialize cnn config\n        video_enc_cfg = load_json(cfg.visual_model_cfg)\n\n        video_enc_cfg['num_frm'] = cfg.num_frm\n        video_enc_cfg['img_size'] = cfg.crop_img_size\n\n        model = AlproForPretrain(\n            model_cfg, \n            input_format=cfg.img_input_format,\n            video_enc_cfg=video_enc_cfg\n            )\n        if cfg.e2e_weights_path:\n            LOGGER.info(f\"Loading e2e weights from {cfg.e2e_weights_path}\")\n            num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2\n            # NOTE strict if False if loaded from ALBEF ckpt\n            load_state_dict_with_pos_embed_resizing(model, \n                                                    cfg.e2e_weights_path, \n                                                    num_patches=num_patches, \n                                                    num_frames=cfg.num_frm, \n                                                    strict=True\n                                                    )\n        else:\n            LOGGER.info(f\"Loading visual weights from {cfg.visual_weights_path}\")\n            model.load_separate_ckpt(\n                visual_weights_path=cfg.visual_weights_path,\n                prompter_weights_path=cfg.teacher_weights_path\n            )\n    else:\n        raise 
NotImplementedError(f\"cfg.model_type not found {cfg.model_type}.\")\n\n    # if cfg.freeze_cnn:\n    #     model.freeze_cnn_backbone()\n    \n    LOGGER.info(\"Moving model to device\") \n    model.to(device)\n    LOGGER.info(\"Completed moving model to device.\") \n\n    LOGGER.info(\"Setup model done!\")\n    return model\n\n\ndef forward_step(cfg, model, batch):\n    \"\"\"shared for training and validation\"\"\"\n    # used to make visual feature copies\n    if not cfg.use_itm:\n        batch[\"itm_labels\"] = None\n    outputs = model(batch)  # dict\n    return outputs\n\n\n@torch.no_grad()\ndef validate(model, val_loader, cfg):\n    model.eval()\n\n    mlm_loss = 0\n    n_mlm_tokens = 0\n    n_mlm_corrects = 0\n    itm_loss = 0\n    n_itm_ex = 0\n    n_itm_corrects = 0\n    itc_loss = 0\n    mpm_loss = 0\n    n_mpm_ex = 0\n    n_mpm_corrects = 0\n    st = time.time()\n    val_log = {'valid/mlm_loss': 0, 'valid/mlm_acc': 0,\n               'valid/itm_loss': 0, 'valid/itm_acc': 0,\n               'valid/mpm_loss': 0, 'valid/mpm_acc': 0,\n               'valid/itc_loss': 0}\n    debug_step = 5\n    val_loaders = val_loader if isinstance(val_loader, dict) else {\n        \"unnamed_val_loader\": val_loader}\n    \n    total_val_iters = 0 \n\n    LOGGER.info(f\"In total {len(val_loaders)} val loaders\")\n    for loader_name, val_loader in val_loaders.items():\n        LOGGER.info(f\"Loop val_loader {loader_name}.\")\n\n        total_val_iters += len(val_loader)\n        for val_step, batch in enumerate(val_loader):\n            # use iter to reset MetaLoader\n            # forward pass\n            outputs = forward_step(cfg, model, batch)\n\n            # mlm\n            mlm_labels = outputs[\"mlm_labels\"]\n            if cfg.use_mlm:\n                mlm_loss += outputs[\"mlm_loss\"].sum().item()\n                mlm_mask = mlm_labels != -100  # (B, Lt)  -100 is the ignored label for cross entropy\n                n_mlm_tokens += mlm_mask.sum().item()\n                if n_mlm_tokens > 0:\n                    n_mlm_corrects += (\n                            outputs[\"mlm_scores\"][mlm_mask].max(\n                                dim=-1)[1] == mlm_labels[mlm_mask]).sum().item()\n                else:\n                    n_mlm_corrects = 0\n\n            # itm\n            if cfg.use_itm:\n                itm_loss += outputs[\"itm_loss\"].sum().item() \n                n_itm_ex += len(outputs[\"itm_labels\"])\n                n_itm_corrects += (\n                        outputs[\"itm_scores\"].max(\n                            dim=-1)[1] == outputs[\"itm_labels\"]).sum().item()\n\n            if cfg.use_itc:\n                itc_loss += outputs[\"itc_loss\"].sum().item()\n\n            if cfg.use_mpm:\n                mpm_labels = outputs[\"mpm_labels\"]\n\n                if mpm_labels is not None:\n                    n_mpm_ex += len(mpm_labels) \n\n                    n_mpm_corrects += (\n                            outputs[\"mpm_logits\"].max(\n                                dim=-1)[1] == outputs[\"mpm_labels\"].max(dim=-1)[1]).sum().item()\n\n                    mpm_loss += outputs[\"mpm_loss\"].sum().item()\n\n            if cfg.debug and val_step >= debug_step:\n                break\n\n    # Gather across all processes\n    # mlm_loss = sum(all_gather_list(mlm_loss))\n    all_gather_mlm_loss = all_gather_list(mlm_loss)\n    mlm_loss = sum(all_gather_mlm_loss)\n    n_mlm_corrects = sum(all_gather_list(n_mlm_corrects))\n    n_mlm_tokens = sum(all_gather_list(n_mlm_tokens))\n\n    
all_gather_itm_loss = all_gather_list(itm_loss)\n    itm_loss = sum(all_gather_itm_loss)\n    n_itm_corrects = sum(all_gather_list(n_itm_corrects))\n    n_itm_ex = sum(all_gather_list(n_itm_ex))\n\n    all_gather_itc_loss = all_gather_list(itc_loss)\n    itc_loss = sum(all_gather_itc_loss)\n\n    all_gather_mpm_loss = all_gather_list(mpm_loss)\n    mpm_loss = sum(all_gather_mpm_loss)\n    n_mpm_corrects = sum(all_gather_list(n_mpm_corrects))\n    n_mpm_ex = sum(all_gather_list(n_mpm_ex))\n\n    if n_mlm_tokens != 0:\n        val_log.update({\n            'valid/mlm_loss': float(mlm_loss),\n            'valid/mlm_acc': float(n_mlm_corrects / n_mlm_tokens)\n        })\n\n    # FIXME check this whether take mean?\n    if n_itm_ex != 0:\n        val_log.update({\n            'valid/itm_loss': float(itm_loss),\n            'valid/itm_acc': float(n_itm_corrects / n_itm_ex)\n        })\n    \n    # FIXME check this whether take mean?\n    if cfg.use_itc:\n        val_log.update({\n            'valid/itc_loss': float(itc_loss),\n        })\n\n    if n_mpm_ex != 0:\n        val_log.update({\n            'valid/mpm_loss': float(mpm_loss),\n            'valid/mpm_acc': float(n_mpm_corrects / n_mpm_ex)\n        })\n\n    TB_LOGGER.log_scalar_dict(val_log)\n    LOGGER.info(f\"validation finished in {int(time.time() - st)} seconds, \"\n                f\"[mlm_acc (per token)]: {val_log['valid/mlm_acc'] * 100:.2f} \"\n                f\"[mpm_acc (per token)]: {val_log['valid/mpm_acc'] * 100:.2f} \"\n                f\"[itm_acc (per example)]: {val_log['valid/itm_acc'] * 100:.2f} \")\n\n    LOGGER.info(\"[mlm_loss]: {} \".format(mlm_loss))\n    LOGGER.info(\"[itm_loss]: {} \".format(itm_loss))\n    LOGGER.info(\"[itc_loss]: {} \".format(itc_loss))\n    LOGGER.info(\"In total, {} validation iters.\".format(total_val_iters))\n\n    model.train()\n    return val_log\n\ndef get_video_prompt_templates():\n    prompts = [\n        'A footage of a {}.',\n        'A footage of the {}.',\n        'A footage of one {}.',\n        'A video of a {}.',\n        'A video of the {}.',\n        'A video of one {}.',\n        'A portrait of a {}.',\n        'A portrait of the {}.',\n        'A portrait of one {}.',\n        'A video footage of a {}.',\n        'A video footage of the {}.',\n        'A video footage of one {}.'\n    ]\n\n    return prompts\n\ndef get_image_prompt_templates():\n    prompts = [\n        # basics\n        'A photo of a {}.',\n        'A photo of the {}.',\n        'A photo of one {}.',\n        'A picture of a {}.',\n        'A picture of the {}.',\n        'A picture of one {}.',\n        # good photo/picture\n        'A good photo of the {}.',\n        'A good photo of a {}.',\n        'A good photo of one {}.',\n        'A good picture of the {}.',\n        'A good picture of a {}.',\n        'A good picture of one {}.'\n    ]\n\n    return prompts\n\n\ndef setup_text_prompts(cfg, tokenizer):\n    entity_filepath = cfg.entity_file_path\n    entity_num = cfg.num_entities\n\n    content = open(entity_filepath).read().split('\\n')[:entity_num]\n    entities = [c.split(' ')[0] for c in content]\n\n    video_prompt_templates = get_video_prompt_templates()\n    image_prompt_templates = get_image_prompt_templates()\n\n    video_prompts = []\n    for template in video_prompt_templates:\n        video_prompts.extend([template.format(e) for e in entities])\n    \n    image_prompts = []\n    for template in image_prompt_templates:\n        image_prompts.extend([template.format(e) for e in 
entities])\n\n    batch_enc_video_prompts = tokenizer.batch_encode_plus(\n        video_prompts,\n        max_length=15,\n        padding=\"max_length\",\n        return_tensors=\"pt\"\n    )\n\n    batch_enc_image_prompts = tokenizer.batch_encode_plus(\n        image_prompts,\n        max_length=15,\n        padding=\"max_length\",\n        return_tensors=\"pt\"\n    )\n\n    return dict(video_prompts=video_prompts, \n                image_prompts=image_prompts,\n                batch_enc_video_prompts=batch_enc_video_prompts,\n                batch_enc_image_prompts=batch_enc_image_prompts\n                )\n\n\ndef start_training():\n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n    cfg = shared_configs.get_sparse_pretraining_args()\n    set_random_seed(cfg.seed)\n\n    n_gpu = hvd.size()\n    # device = torch.device(\"cuda\", hvd.local_rank())\n    # torch.cuda.set_device(hvd.local_rank())\n\n    # This resolves the issue GPU 0 always has more processes running and more GPU-RAM.\n    # c.f. https://github.com/horovod/horovod/issues/2625#issuecomment-868134876\n    os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())\n    device = torch.device(\"cuda\", 0)\n    torch.cuda.set_device(0)\n\n    if hvd.rank() != 0:\n        LOGGER.disabled = True\n    LOGGER.info(f\"device: {device} n_gpu: {n_gpu}, \"\n                f\"rank: {hvd.rank()}, 16-bits training: {cfg.fp16}\")\n\n    model = setup_model(cfg, device=device)\n    model.train()\n\n    optimizer = setup_e2e_optimizer(model, cfg)\n\n    # Horovod: (optional) compression algorithm.compressin\n    compression = hvd.Compression.none\n    optimizer = hvd.DistributedOptimizer(\n        optimizer, named_parameters=model.named_parameters(),\n        compression=compression)\n\n    #  Horovod: broadcast parameters & optimizer state.\n    compression = hvd.Compression.none\n    hvd.broadcast_parameters(model.state_dict(), root_rank=0)\n    hvd.broadcast_optimizer_state(optimizer, root_rank=0)\n\n    model, optimizer = amp.initialize(\n        model, optimizer, enabled=cfg.fp16, opt_level='O1')\n        # keep_batchnorm_fp32=True)\n\n    # prepare data\n    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)\n    train_loaders, val_loaders = setup_dataloaders(cfg, tokenizer)\n    train_loader = MetaLoader(train_loaders,\n                              accum_steps=cfg.gradient_accumulation_steps,\n                              distributed=n_gpu > 1)\n    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)\n    train_loader = PrefetchLoader(train_loader, img_norm)\n    val_loaders = {k: PrefetchLoader(v, img_norm)\n                   for k, v in val_loaders.items()}\n\n    # compute the number of steps and update cfg\n    total_train_batch_size = int(\n        n_gpu * cfg.train_batch_size *\n        cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)\n    total_n_epochs = cfg.num_train_epochs\n    cfg.num_train_steps = int(math.ceil(\n        1. * train_loader.n_batches_in_epoch * total_n_epochs /\n        (n_gpu * cfg.gradient_accumulation_steps)))\n    cfg.valid_steps = int(math.ceil(\n        1. * cfg.num_train_steps / cfg.num_valid /\n        cfg.min_valid_steps)) * cfg.min_valid_steps\n    actual_num_valid = int(math.floor(\n        1. 
* cfg.num_train_steps / cfg.valid_steps)) + 1\n    \n    save_steps = int(cfg.save_steps_ratio * cfg.num_train_steps)\n\n    # restore\n    restorer = TrainingRestorer(cfg, model, optimizer)\n    global_step = restorer.global_step\n    TB_LOGGER.global_step = global_step\n    if hvd.rank() == 0:\n        LOGGER.info(\"Saving training meta...\")\n        save_training_meta(cfg)\n        LOGGER.info(\"Saving training done...\")\n        TB_LOGGER.create(join(cfg.output_dir, 'log'))\n        pbar = tqdm(total=cfg.num_train_steps)\n        model_saver = ModelSaver(join(cfg.output_dir, \"ckpt\"))\n        add_log_to_file(join(cfg.output_dir, \"log\", \"log.txt\"))\n    else:\n        LOGGER.disabled = True\n        pbar = NoOp()\n        model_saver = NoOp()\n        restorer = NoOp()\n\n    if global_step > 0:\n        pbar.update(global_step)\n\n    LOGGER.info(cfg)\n    LOGGER.info(\"Starting training...\")\n    LOGGER.info(f\"***** Running training with {n_gpu} GPUs *****\")\n    LOGGER.info(f\"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}\")\n    LOGGER.info(f\"  max_n_example_per_group = {cfg.max_n_example_per_group}\")\n    LOGGER.info(f\"  Accumulate steps = {cfg.gradient_accumulation_steps}\")\n    LOGGER.info(f\"  Total batch size = #GPUs * Single-GPU batch size * \"\n                f\"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}\")\n    LOGGER.info(f\"  Total #batches - single epoch = {train_loader.n_batches_in_epoch}.\")\n    LOGGER.info(f\"  Total #steps = {cfg.num_train_steps}\")\n    LOGGER.info(f\"  Total #epochs = {total_n_epochs}.\")\n    LOGGER.info(f\"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times\")\n\n\n    # quick hack for amp delay_unscale bug\n    with optimizer.skip_synchronize():\n        optimizer.zero_grad()\n        if global_step == 0:\n            optimizer.step()\n    debug_step = 20\n\n    tasks = []\n    for name, flag in zip([\"mlm\", \"itm\", \"itc\", \"mpm\"], [cfg.use_mlm, cfg.use_itm, cfg.use_itc, cfg.use_mpm]):\n        if flag:\n            tasks.append(name)\n    task2loss = {t: RunningMeter(f'train_loss/{t}')\n                 for t in tasks}\n    task2loss[\"loss\"] = RunningMeter('train_loss/loss')\n\n    train_log = {'train/mlm_acc': 0,\n                 'train/itm_acc': 0,\n                 'train/mpm_acc': 0,\n                 }\n\n    # create tokenized promopts\n    if not cfg.e2e_weights_path and cfg.use_mpm:\n        text_prompts = setup_text_prompts(cfg, tokenizer)\n        model.build_text_prompts(text_prompts)\n\n    for step, (task, batch) in enumerate(train_loader):\n        # forward pass\n        outputs = forward_step(cfg, model, batch)\n        mlm_loss, itm_loss, itc_loss, mpm_loss = 0, 0, 0, 0\n        # mlm_loss, itm_loss, itc_loss = 0, 0, 0\n        if cfg.use_mlm:\n            # mlm_loss = outputs[\"mlm_loss\"].mean()\n            mlm_loss = outputs[\"mlm_loss\"]\n            mlm_mask = outputs[\"mlm_labels\"] != -100\n            n_mlm_tokens = mlm_mask.sum().item()\n\n            task2loss[\"mlm\"](mlm_loss.item())\n\n        if cfg.use_itm:\n            itm_loss = outputs[\"itm_loss\"]\n            task2loss[\"itm\"](itm_loss.item())\n        \n        if cfg.use_itc:\n            itc_loss = outputs[\"itc_loss\"]\n            task2loss[\"itc\"](itc_loss.item())\n\n        if cfg.use_mpm:\n            mpm_loss = outputs[\"mpm_loss\"]\n            task2loss[\"mpm\"](mpm_loss.item())\n\n        loss = mlm_loss + itm_loss + itc_loss + mpm_loss\n     
   task2loss[\"loss\"](loss.item())\n\n        if step % cfg.log_interval == 0:\n            # training mlm token acc\n            if n_mlm_tokens > 0:\n                n_mlm_corrects = (\n                        outputs[\"mlm_scores\"][mlm_mask].max(\n                            dim=-1)[1] == outputs['mlm_labels'][mlm_mask]).sum().item()\n            else:\n                n_mlm_corrects = 0\n\n            # training itm acc\n            n_itm_ex = len(outputs[\"itm_labels\"])\n            n_itm_corrects = (\n                    outputs[\"itm_scores\"].max(\n                        dim=-1)[1] == outputs[\"itm_labels\"]).sum().item()\n\n            # training mpm acc\n            mpm_labels = outputs[\"mpm_labels\"]\n\n            if mpm_labels is not None:\n                n_mpm_ex = len(mpm_labels)\n                n_mpm_corrects = (\n                        outputs[\"mpm_logits\"].max(\n                            dim=-1)[1] == outputs[\"mpm_labels\"].max(dim=-1)[1]).sum().item()\n                mpm_acc = float(n_mpm_corrects / n_mpm_ex)\n            else:\n                mpm_acc = 0.\n\n            train_log.update({\n                'train/mlm_acc': float(n_mlm_corrects / n_mlm_tokens),\n                'train/itm_acc': float(n_itm_corrects / n_itm_ex),\n                'train/mpm_acc': mpm_acc\n            })\n\n            TB_LOGGER.log_scalar_dict(train_log)\n\n        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0\n        with amp.scale_loss(\n                loss, optimizer, delay_unscale=delay_unscale\n                ) as scaled_loss:\n            scaled_loss.backward()\n            zero_none_grad(model)\n            optimizer.synchronize()\n\n        # optimizer\n        if (step + 1) % cfg.gradient_accumulation_steps == 0:\n            global_step += 1\n            if (step + 1) % cfg.log_interval == 0:\n                TB_LOGGER.log_scalar_dict({l.name: l.val\n                                        for l in task2loss.values()\n                                        if l.val is not None})\n            n_epoch = int(1. 
* n_gpu * cfg.gradient_accumulation_steps *\n                          global_step / train_loader.n_batches_in_epoch)\n\n            # learning rate scheduling for the whole model\n            lr_this_step = get_lr_sched(\n                global_step, cfg.decay, cfg.learning_rate,\n                cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,\n                decay_epochs=cfg.step_decay_epochs,\n                multi_step_epoch=n_epoch)\n\n            # Hardcoded param group length\n            # assert len(optimizer.param_groups) == 8\n            for pg_n, param_group in enumerate(\n                    optimizer.param_groups):\n                    param_group['lr'] = lr_this_step\n\n            if (step + 1) % cfg.log_interval == 0:\n                TB_LOGGER.add_scalar(\n                    \"train/lr\", lr_this_step, global_step)\n\n            # update model params\n            if cfg.grad_norm != -1:\n                # import pdb; pdb.set_trace()\n                grad_norm = clip_grad_norm_(\n                    amp.master_params(optimizer), cfg.grad_norm)\n                if (step + 1) % cfg.log_interval == 0:\n                    TB_LOGGER.add_scalar(\"train/grad_norm\", grad_norm, global_step)\n            TB_LOGGER.step()\n\n            # Check if there is None grad\n            none_grads = [\n                p[0] for p in model.named_parameters()\n                if p[1].requires_grad and p[1].grad is None]\n\n            assert len(none_grads) == 0, f\"{none_grads}\"\n\n            with optimizer.skip_synchronize():\n                optimizer.step()\n                optimizer.zero_grad()\n            restorer.step()\n            pbar.update(1)\n\n            # validate and checkpoint\n            if global_step % cfg.valid_steps == 0:\n                LOGGER.info(f'Step {global_step}: start validation')\n                validate(model, val_loaders, cfg)\n                model_saver.save(step=global_step, model=model)\n            \n            if global_step % save_steps == 0:\n                LOGGER.info(f'Step {global_step}: saving model checkpoints.')\n                model_saver.save(step=global_step, model=model)\n\n        if global_step >= cfg.num_train_steps:\n            break\n\n        if cfg.debug and global_step >= debug_step:\n            break\n\n    if global_step % cfg.valid_steps != 0:\n        LOGGER.info(f'Step {global_step}: start validation')\n        validate(model, val_loaders, cfg)\n        model_saver.save(step=global_step, model=model)\n\n\nif __name__ == '__main__':\n    # Initialize Horovod\n    hvd.init()\n    start_training()\n"
  },
  {
    "path": "src/tasks/run_video_qa.py",
    "content": "import math\nimport os\nimport random\nimport time\nfrom collections import defaultdict\nfrom os.path import join\n\nimport horovod.torch as hvd\nimport torch\nfrom apex import amp\nfrom easydict import EasyDict as edict\nfrom src.configs.config import shared_configs\nfrom src.datasets.data_utils import ImageNorm, mk_input_group\nfrom src.datasets.dataloader import InfiniteIterator, PrefetchLoader\nfrom src.datasets.dataset_video_qa import (AlproVideoQADataset,\n                                           VideoQACollator)\nfrom src.modeling.alpro_models import AlproForSequenceClassification\nfrom src.optimization.sched import get_lr_sched\nfrom src.optimization.utils import setup_e2e_optimizer\nfrom src.utils.basic_utils import (get_rounded_percentage, load_json,\n                                   load_jsonl, save_json)\nfrom src.utils.distributed import all_gather_list\nfrom src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer\nfrom src.utils.load_save import (ModelSaver,\n                                 load_state_dict_with_pos_embed_resizing,\n                                 save_training_meta)\nfrom src.utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file\nfrom src.utils.misc import NoOp, set_random_seed, zero_none_grad\nfrom torch.nn.utils import clip_grad_norm_\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nfrom tqdm import tqdm\nfrom transformers import BertConfig, BertTokenizerFast\n\n\ndef mk_qa_dataloader(task_type, anno_path, lmdb_dir, cfg, tokenizer,\n                          is_train=True, return_label=True):\n    \"\"\"\n    Returns:\n        list(dict), each dict is\n            msrvtt_qa: {\n                \"answer\": \"couch\",\n                \"question\": \"what are three people sitting on?\",\n                \"video_id\": \"video6513\",\n                \"answer_type\": \"what\"\n                }\n    \"\"\"\n    raw_datalist = load_jsonl(anno_path)\n    LOGGER.info(f\"Loaded data size {len(raw_datalist)}\")\n    if cfg.data_ratio != 1.0:\n        random.shuffle(raw_datalist)\n        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]\n        LOGGER.info(f\"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}\")\n\n    datalist = []\n    qid = 0\n    for raw_d in raw_datalist:\n        d = dict(\n            question=raw_d[\"question\"],\n            vid_id=raw_d[\"video_id\"],\n            answer=raw_d[\"answer\"],  # int or str\n            question_id=qid  # be careful, it is not unique across splits\n        )\n        qid += 1\n\n        d[\"answer_type\"] = raw_d[\"answer_type\"]\n\n        datalist.append(d)\n\n    LOGGER.info(f\"datalist {len(datalist)}\")\n\n    grouped = defaultdict(list)  # examples grouped by image/video id\n    for d in datalist:\n        grouped[d[\"vid_id\"]].append(d)\n    LOGGER.info(f\"grouped {len(grouped)}\")\n\n    # each group has a single image with multiple questions\n    group_datalist = mk_input_group(\n        grouped,\n        max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1,  # force 1 in eval,\n        is_train=is_train\n    )\n    LOGGER.info(f\"group_datalist {len(group_datalist)}\")\n\n    ans2label = load_json(cfg.ans2label_path)\n\n    frm_sampling_strategy = cfg.frm_sampling_strategy\n    if not is_train:\n        # frm_sampling_strategy = \"middle\"\n        frm_sampling_strategy = \"uniform\"\n    \n    if 'msvd' in cfg.task:\n        video_fmt = 
'.avi'\n    else:\n        video_fmt = '.mp4'\n\n    dataset = AlproVideoQADataset(\n        task_type=cfg.task,\n        datalist=group_datalist,\n        tokenizer=tokenizer,\n        img_lmdb_dir=lmdb_dir,\n        ans2label=ans2label,\n        max_img_size=cfg.crop_img_size,\n        max_txt_len=cfg.max_txt_len,\n        fps=cfg.fps,\n        num_frm=cfg.num_frm,\n        frm_sampling_strategy=frm_sampling_strategy,\n        ensemble_n_clips=cfg.train_n_clips if is_train else cfg.inference_n_clips,\n        return_label=return_label,\n        is_train=is_train,\n        img_db_type='rawvideo',\n        video_fmt=video_fmt\n    )\n    LOGGER.info(f\"is_train {is_train}, dataset size {len(dataset)} groups, \"\n                f\"each group {cfg.max_n_example_per_group if is_train else 1}\")\n    if cfg.do_inference:\n        batch_size = cfg.inference_batch_size\n    else:\n        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size\n    sampler = DistributedSampler(\n        dataset, num_replicas=hvd.size(), rank=hvd.rank(),\n        shuffle=is_train)\n    vqa_collator = VideoQACollator(tokenizer=tokenizer,\n                                   max_length=cfg.max_txt_len,\n                                   task_type=cfg.task)\n    dataloader = DataLoader(dataset,\n                            batch_size=batch_size,\n                            shuffle=False,\n                            sampler=sampler,\n                            num_workers=cfg.n_workers,\n                            pin_memory=cfg.pin_mem,\n                            collate_fn=vqa_collator.collate_batch)\n    return dataloader\n\n\ndef setup_dataloaders(cfg, tokenizer):\n    LOGGER.info(\"Init. train_loader and val_loader...\")\n    train_loader = mk_qa_dataloader(\n        task_type=cfg.task,\n        anno_path=cfg.train_datasets[0].txt[cfg.task],\n        lmdb_dir=cfg.train_datasets[0].img,\n        cfg=cfg, tokenizer=tokenizer, is_train=True\n    )\n    val_loader = mk_qa_dataloader(\n        task_type=cfg.task,\n        anno_path=cfg.val_datasets[0].txt[cfg.task],\n        lmdb_dir=cfg.val_datasets[0].img,\n        cfg=cfg, tokenizer=tokenizer, is_train=False, return_label=False\n    )\n    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)\n    train_loader = PrefetchLoader(train_loader, img_norm)\n    val_loader = PrefetchLoader(val_loader, img_norm)\n    return train_loader, val_loader\n\n\ndef setup_model(cfg, device=None):\n    LOGGER.info(\"Setup model...\")\n    # has to be a BertConfig instance\n    model_cfg = load_json(cfg.model_config)\n    model_cfg = BertConfig(**model_cfg)\n    # add downstream model config\n    add_attr_list = [\n        \"num_labels\", \"classifier\", \"cls_hidden_scale\",\n        \"loss_type\",\n    ]\n    for k in add_attr_list:\n        setattr(model_cfg, k, cfg[k])\n    transformer_model_cls = AlproForSequenceClassification\n\n    # we separate the CNN and the transformer in order to use different optimizer for each\n    # transformer still has a CNN layer inside, used to down sample grid.\n    LOGGER.info(\"setup e2e model\")\n\n    video_enc_cfg = load_json(cfg.visual_model_cfg)\n\n    video_enc_cfg['num_frm'] = cfg.num_frm\n    video_enc_cfg['img_size'] = cfg.crop_img_size\n\n    model = AlproForSequenceClassification(\n        model_cfg, \n        input_format=cfg.img_input_format,\n        video_enc_cfg=video_enc_cfg\n        )\n\n    if cfg.e2e_weights_path:\n        LOGGER.info(f\"Loading e2e weights from {cfg.e2e_weights_path}\")\n 
       num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2\n        # NOTE strict is False if loaded from ALBEF ckpt\n        load_state_dict_with_pos_embed_resizing(model, \n                                                cfg.e2e_weights_path, \n                                                num_patches=num_patches, \n                                                num_frames=cfg.num_frm, \n                                                strict=False,\n                                                remove_text_encoder_prefix=True\n                                                )\n        # LOGGER.info(f\"Loading e2e weights from {cfg.e2e_weights_path}\")\n        # load_state_dict_with_mismatch(model, cfg.e2e_weights_path)\n    else:\n        LOGGER.info(f\"Loading visual weights from {cfg.visual_weights_path}\")\n        LOGGER.info(f\"Loading bert weights from {cfg.bert_weights_path}\")\n        model.load_separate_ckpt(\n            visual_weights_path=cfg.visual_weights_path,\n            bert_weights_path=cfg.bert_weights_path\n        )\n\n    # if cfg.freeze_cnn:\n    #     model.freeze_cnn_backbone()\n    model.to(device)\n\n    LOGGER.info(\"Setup model done!\")\n    return model\n\n\ndef forward_step(model, batch, cfg):\n    \"\"\"shared for training and validation\"\"\"\n    if cfg.task in [\"action\", \"transition\"]:\n        repeat_counts = [e * cfg.num_labels for e in batch[\"n_examples_list\"]]\n        batch[\"n_examples_list\"] = repeat_counts\n\n    outputs = model(batch)  # dict\n    return outputs\n\n\n@torch.no_grad()\ndef validate(model, val_loader, cfg, train_global_step, eval_score=True):\n    \"\"\"use eval_score=False when doing inference on test sets where answers are not available\"\"\"\n    model.eval()\n\n    loss = 0.\n    n_ex = 0\n    qa_results = []\n    st = time.time()\n    debug_step = 5\n    pbar = tqdm(total=len(val_loader))\n    for val_step, batch in enumerate(val_loader):\n        # forward pass\n        question_ids = batch[\"question_ids\"]\n        bsz = len(question_ids)\n        # used to make visual feature copies\n        del batch[\"question_ids\"]\n        # add visual part into the mini batch and perform inference\n        mini_batch = dict()\n        for k, v in batch.items():\n            if k != \"visual_inputs\":\n                mini_batch[k] = v\n\n        n_ex += len(question_ids)\n        # multi-frame test, scores across frames of the same video will be pooled together\n        pool_method = cfg.score_agg_func\n        # could be 1, where only a single clip is evaluated\n        num_clips = cfg.inference_n_clips\n        num_frm = cfg.num_frm\n        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)\n        new_visual_shape = (bsz, num_clips, num_frm) + batch[\"visual_inputs\"].shape[2:]\n        visual_inputs = batch[\"visual_inputs\"].view(*new_visual_shape)\n        logits = []\n        losses = []\n        for clip_idx in range(num_clips):\n            # (B, num_frm, C, H, W)\n            mini_batch[\"visual_inputs\"] = visual_inputs[:, clip_idx]\n            mini_batch[\"n_examples_list\"] = batch[\"n_examples_list\"]\n            outputs = forward_step(model, mini_batch, cfg)\n            logits.append(outputs[\"logits\"].cpu())\n            _loss = outputs[\"loss\"].sum().item() if isinstance(\n                outputs[\"loss\"], torch.Tensor) else 0\n            losses.append(_loss)\n        loss += (sum(losses) / num_clips)\n\n        logits = torch.stack(logits)  # 
(num_frm, B, 5)\n        if pool_method == \"mean\":\n            logits = logits.mean(0)  # (B, 5)\n        elif pool_method == \"max\":\n            logits = logits.max(0)[0]  # (B, 5)\n        elif pool_method == \"lse\":\n            logits = logits.permute(1, 0, 2).contiguous()  # (B, num_frm, 5), pooling will be done in CE\n            logits = torch.logsumexp(logits, dim=1)  # torch.exp alone might be too large and unstable\n        else:\n            raise ValueError(f\"Invalid value for pool_method, \"\n                             f\"got {pool_method}, expect one of [`mean`, `max`, `lse`]\")\n\n        if cfg.task in [\"action\", \"transition\", \"frameqa\", \"msrvtt_qa\", \"msvd_qa\"]:\n            # cross entropy\n            pred_labels = logits.max(dim=-1)[1].data.tolist()\n        else:\n            # mse\n            preds = (logits + 0.5).long().clamp(min=1, max=10)\n            pred_labels = preds.data.squeeze().tolist()\n        for qid, pred_label in zip(question_ids, pred_labels):\n            qa_results.append(dict(\n                question_id=qid,\n                answer=pred_label,\n                data=val_loader.dataset.qid2data[qid]\n            ))\n        pbar.update(1)\n        if cfg.debug and val_step >= debug_step:\n            break\n\n    if cfg.debug:\n        LOGGER.info(qa_results[:10])\n    n_ex_per_rank = all_gather_list(n_ex)\n    loss = sum(all_gather_list(loss))\n    n_ex = sum(all_gather_list(n_ex))\n    # average loss for each example\n    val_log = {f'valid/loss': float(loss / n_ex)}\n    if eval_score:\n        LOGGER.info(f\"QA Task [{cfg.task}], \"\n                    f\"{len(qa_results)} qa_results,\"\n                    f\"3 examples here: {qa_results[:3]}\")\n        vqa_scores = val_loader.dataset.evaluate_qa(qa_results)\n        # print(f\"{hvd.rank()}: {vqa_scores}\")\n\n        # Gather scores\n        scores_per_rank = all_gather_list(vqa_scores)\n        gathered_scores = {}\n        if \"ratios\" in scores_per_rank[0]:\n            gathered_ratios = {\n                k: [0, 0] for k, _ in scores_per_rank[0][\"ratios\"].items()}\n            # Gather ratios\n            for rank_id in range(len(n_ex_per_rank)):\n                current_ratios = scores_per_rank[rank_id][\"ratios\"]\n                for k, v in current_ratios.items():\n                    gathered_ratios[k][1] += v[1]\n            for k, v in gathered_ratios.items():\n                gathered_ratios[k][0] = get_rounded_percentage(\n                    1. * v[1] / n_ex)\n            gathered_scores[\"ratios\"] = gathered_ratios\n\n        # FIXME: Gather scores become complicated due to np.mean and dict format.\n        for scores_k, _ in vqa_scores.items():\n            if \"ratio\" in scores_k:\n                continue\n            gathered_v = 0\n            for rank_id, n in enumerate(n_ex_per_rank):\n                curr_acc, curr_n_ex = 0, 0\n                if \"overall\" in scores_k:\n                    curr_acc = scores_per_rank[rank_id][scores_k] * n\n                else:\n                    if \"ratios\" in scores_per_rank[0]:\n                        curr_n_ex = scores_per_rank[\n                                rank_id][\"ratios\"][\n                                    scores_k.replace(\"acc\", \"ratio\")][1]\n                        curr_acc = scores_per_rank[rank_id][\n                            scores_k] * curr_n_ex\n                gathered_v += curr_acc\n            if \"overall\" in scores_k:\n                gathered_v = gathered_v * 1. 
/ n_ex\n            else:\n                if \"ratios\" in scores_per_rank[0]:\n                    _num = gathered_ratios[\n                        scores_k.replace(\"acc\", \"ratio\")][1]\n                    gathered_v = gathered_v * 1. / _num if _num != 0 else 0\n            if cfg.task in [\"action\", \"transition\", \"frameqa\", \"msrvtt_qa\", \"msvd_qa\"]:\n                gathered_scores[scores_k] = get_rounded_percentage(\n                    gathered_v)\n            else:\n                gathered_scores[scores_k] = round(gathered_v, 2)\n\n        for k, v in gathered_scores.items():\n            if \"ratio\" not in k:\n                val_log[f'valid/{k}'] = v\n    else:\n        LOGGER.info(\"eval_score = False, no scores are calculated.\")\n        gathered_scores = 0\n\n    TB_LOGGER.log_scalar_dict(val_log)\n    LOGGER.info(f\"validation finished in {int(time.time() - st)} seconds.\"\n                f\"{gathered_scores}\")\n\n    model.train()\n    return qa_results, gathered_scores\n\n\ndef start_training(cfg):\n    set_random_seed(cfg.seed)\n\n    n_gpu = hvd.size()\n    cfg.n_gpu = n_gpu\n    device = torch.device(\"cuda\", hvd.local_rank())\n    torch.cuda.set_device(hvd.local_rank())\n    if hvd.rank() != 0:\n        LOGGER.disabled = True\n    LOGGER.info(\"device: {} n_gpu: {}, rank: {}, \"\n                \"16-bits training: {}\".format(\n                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))\n\n    model = setup_model(cfg, device=device)\n    model.train()\n    optimizer = setup_e2e_optimizer(model, cfg)\n\n    # Horovod: (optional) compression algorithm.\n    compression = hvd.Compression.none\n    optimizer = hvd.DistributedOptimizer(\n        optimizer, named_parameters=model.named_parameters(),\n        compression=compression)\n\n    #  Horovod: broadcast parameters & optimizer state.\n    hvd.broadcast_parameters(model.state_dict(), root_rank=0)\n    hvd.broadcast_optimizer_state(optimizer, root_rank=0)\n\n    model, optimizer = amp.initialize(\n        model, optimizer, enabled=cfg.fp16, opt_level='O2',\n        keep_batchnorm_fp32=True)\n\n    # prepare data\n    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)\n    train_loader, val_loader = setup_dataloaders(cfg, tokenizer)\n\n    # compute the number of steps and update cfg\n    total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group\n    total_train_batch_size = int(\n        n_gpu * cfg.train_batch_size *\n        cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)\n    cfg.num_train_steps = int(math.ceil(\n        1. * cfg.num_train_epochs * total_n_examples / total_train_batch_size))\n    cfg.valid_steps = int(math.ceil(\n        1. * cfg.num_train_steps / cfg.num_valid /\n        cfg.min_valid_steps)) * cfg.min_valid_steps\n    actual_num_valid = int(math.floor(\n        1. 
* cfg.num_train_steps / cfg.valid_steps)) + 1\n\n    # restore\n    restorer = TrainingRestorer(cfg, model, optimizer)\n    global_step = restorer.global_step\n    TB_LOGGER.global_step = global_step\n    if hvd.rank() == 0:\n        LOGGER.info(\"Saving training meta...\")\n        save_training_meta(cfg)\n        LOGGER.info(\"Saving training done...\")\n        TB_LOGGER.create(join(cfg.output_dir, 'log'))\n        pbar = tqdm(total=cfg.num_train_steps)\n        model_saver = ModelSaver(join(cfg.output_dir, \"ckpt\"))\n        add_log_to_file(join(cfg.output_dir, \"log\", \"log.txt\"))\n    else:\n        LOGGER.disabled = True\n        pbar = NoOp()\n        model_saver = NoOp()\n        restorer = NoOp()\n\n    if global_step > 0:\n        pbar.update(global_step)\n\n    LOGGER.info(cfg)\n    LOGGER.info(\"Starting training...\")\n    LOGGER.info(f\"***** Running training with {n_gpu} GPUs *****\")\n    LOGGER.info(f\"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}\")\n    LOGGER.info(f\"  max_n_example_per_group = {cfg.max_n_example_per_group}\")\n    LOGGER.info(f\"  Accumulate steps = {cfg.gradient_accumulation_steps}\")\n    LOGGER.info(f\"  Total batch size = #GPUs * Single-GPU batch size * \"\n                f\"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}\")\n    LOGGER.info(f\"  Total #epochs = {cfg.num_train_epochs}\")\n    LOGGER.info(f\"  Total #steps = {cfg.num_train_steps}\")\n    LOGGER.info(f\"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times\")\n\n    # quick hack for amp delay_unscale bug\n    with optimizer.skip_synchronize():\n        optimizer.zero_grad()\n        if global_step == 0:\n            optimizer.step()\n    debug_step = 3\n    running_loss = RunningMeter('train_loss')\n    for step, batch in enumerate(InfiniteIterator(train_loader)):\n        # forward pass\n        bsz = len(batch[\"question_ids\"])\n        del batch[\"question_ids\"]\n        mini_batch = dict()\n        for k, v in batch.items():\n            if k != \"visual_inputs\":\n                mini_batch[k] = v\n\n        pool_method = cfg.score_agg_func\n        # could be 1, where only a single clip is used\n        num_clips = cfg.train_n_clips\n        num_frm = cfg.num_frm\n        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)\n        new_visual_shape = (bsz, num_clips, num_frm) + batch[\"visual_inputs\"].shape[2:]\n        visual_inputs = batch[\"visual_inputs\"].view(*new_visual_shape)\n        logits = []\n        for clip_idx in range(num_clips):\n            # (B, num_frm, C, H, W)\n            mini_batch[\"visual_inputs\"] = visual_inputs[:, clip_idx]\n            mini_batch[\"n_examples_list\"] = batch[\"n_examples_list\"]\n            # outputs = forward_step(model, mini_batch, cfg)\n            outputs = forward_step(model, mini_batch, cfg)\n            logits.append(outputs)\n            # the losses are cross entropy and mse, no need to * num_labels\n\n            loss = outputs['loss']\n\n        loss = loss.mean()\n\n        running_loss(loss.item())\n        # backward pass\n        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0\n        with amp.scale_loss(\n                loss, optimizer, delay_unscale=delay_unscale\n                ) as scaled_loss:\n            scaled_loss.backward()\n            zero_none_grad(model)\n            optimizer.synchronize()\n\n        # optimizer\n        if (step + 1) % cfg.gradient_accumulation_steps == 0:\n       
     global_step += 1\n\n            # learning rate scheduling\n            n_epoch = int(1. * total_train_batch_size * global_step\n                          / total_n_examples)\n            # learning rate scheduling cnn\n            lr_this_step = get_lr_sched(\n                global_step, cfg.decay, cfg.learning_rate,\n                cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,\n                decay_epochs=cfg.step_decay_epochs,\n                multi_step_epoch=n_epoch)\n\n            # Hardcoded param group length\n            for pg_n, param_group in enumerate(\n                    optimizer.param_groups):\n                    param_group['lr'] = lr_this_step\n\n            if step % cfg.log_interval == 0:\n                TB_LOGGER.add_scalar(\n                    \"train/lr\", lr_this_step, global_step)\n\n            TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)\n\n            # update model params\n            if cfg.grad_norm != -1:\n                grad_norm = clip_grad_norm_(\n                    amp.master_params(optimizer),\n                    cfg.grad_norm)\n                TB_LOGGER.add_scalar(\n                    \"train/grad_norm\", grad_norm, global_step)\n            TB_LOGGER.step()\n\n            # Check if there is None grad\n            none_grads = [\n                p[0] for p in model.named_parameters()\n                if p[1].requires_grad and p[1].grad is None]\n\n            assert len(none_grads) == 0, f\"{none_grads}\"\n\n            with optimizer.skip_synchronize():\n                optimizer.step()\n                optimizer.zero_grad()\n            restorer.step()\n            pbar.update(1)\n\n            # checkpoint\n            if global_step % cfg.valid_steps == 0:\n                LOGGER.info(f'Step {global_step}: start validation')\n                validate(\n                    model, val_loader, cfg, global_step)\n                model_saver.save(step=global_step, model=model)\n        if global_step >= cfg.num_train_steps:\n            break\n\n        if cfg.debug and global_step >= debug_step:\n            break\n\n    if global_step % cfg.valid_steps != 0:\n        LOGGER.info(f'Step {global_step}: start validation')\n        qa_results, qa_scores = validate(\n            model, val_loader, cfg, global_step)\n        model_saver.save(step=global_step, model=model)\n\n\ndef start_inference(cfg):\n    set_random_seed(cfg.seed)\n    n_gpu = hvd.size()\n    device = torch.device(\"cuda\", hvd.local_rank())\n    torch.cuda.set_device(hvd.local_rank())\n    if hvd.rank() != 0:\n        LOGGER.disabled = True\n\n    inference_res_dir = join(\n        cfg.output_dir,\n        f\"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/\"\n        f\"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}\"\n    )\n\n    if hvd.rank() == 0:\n        os.makedirs(inference_res_dir, exist_ok=True)\n        save_json(cfg, join(inference_res_dir, \"raw_args.json\"),\n                  save_pretty=True)\n\n    LOGGER.info(\"device: {} n_gpu: {}, rank: {}, \"\n                \"16-bits training: {}\".format(\n                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))\n\n    # overwrite cfg with stored_cfg,\n    # but skip keys containing the keyword 'inference'\n    stored_cfg_path = join(cfg.output_dir, \"log/args.json\")\n    stored_cfg = edict(load_json(stored_cfg_path))\n    for k, v in cfg.items():\n        if k in stored_cfg and \"inference\" not in k:\n            setattr(cfg, 
k, stored_cfg[k])\n\n    # setup models\n    cfg.model_config = join(cfg.output_dir, \"log/model_config.json\")\n    e2e_weights_path = join(\n        cfg.output_dir, f\"ckpt/model_step_{cfg.inference_model_step}.pt\")\n    cfg.e2e_weights_path = e2e_weights_path\n    model = setup_model(cfg, device=device)\n    model.eval()\n\n    # FIXME separate scaling for each loss\n    model = amp.initialize(\n        model, enabled=cfg.fp16, opt_level='O2')\n\n    global_step = 0\n    # prepare data\n    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)\n    cfg.data_ratio = 1.\n    val_loader = mk_qa_dataloader(\n        task_type=cfg.task,\n        anno_path=cfg.inference_txt_db,\n        lmdb_dir=cfg.inference_img_db,\n        cfg=cfg, tokenizer=tokenizer,\n        is_train=False,\n        return_label=False\n    )\n    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)\n    val_loader = PrefetchLoader(val_loader, img_norm)\n\n    LOGGER.info(cfg)\n    LOGGER.info(\"Starting inference...\")\n    LOGGER.info(f\"***** Running inference with {n_gpu} GPUs *****\")\n    LOGGER.info(f\"  Batch size = {cfg.inference_batch_size}\")\n\n    LOGGER.info(f'Step {global_step}: start validation')\n    qa_results, qa_scores = validate(\n        model, val_loader, cfg, global_step,\n        eval_score=True)  # cfg.inference_split == \"val\"\n\n    if hvd.rank() == 0:\n        save_json(cfg, join(inference_res_dir, \"merged_args.json\"),\n                  save_pretty=True)\n        save_json(qa_scores, join(inference_res_dir, \"scores.json\"),\n                  save_pretty=True)\n\n    # ###### Saving with Horovod ####################\n    # dummy sync\n    _ = None\n    all_gather_list(_)\n    if n_gpu > 1:\n        # with retrial, as azure blob fails occasionally.\n        max_save_load_trial = 10\n        save_trial = 0\n        while save_trial < max_save_load_trial:\n            try:\n                LOGGER.info(f\"Save results trial NO. {save_trial}\")\n                save_json(\n                    qa_results,\n                    join(inference_res_dir, f\"results_rank{hvd.rank()}.json\"))\n                break\n            except Exception as e:\n                save_trial += 1\n    # dummy sync\n    _ = None\n    all_gather_list(_)\n    # join results\n    if n_gpu > 1 and hvd.rank() == 0:\n        qa_results = []\n        for rk in range(n_gpu):\n            qa_results.extend(load_json(\n                join(inference_res_dir, f\"results_rank{rk}.json\")))\n        LOGGER.info(f'results joined')\n\n    if hvd.rank() == 0:\n        save_json(\n            qa_results,\n            join(inference_res_dir, f\"results_all.json\"))\n        LOGGER.info(f'all results written')\n\n\nif __name__ == '__main__':\n    # Initialize Horovod\n    hvd.init()\n    input_cfg = shared_configs.get_video_qa_args()\n    if input_cfg.do_inference:\n        # assert hvd.size() == 1, \\\n        #     \"Please use single GPU for evaluation! \" \\\n        #     \"Multi-GPU might miss some examples.\"\n        start_inference(input_cfg)\n    else:\n        start_training(input_cfg)\n"
  },
  {
    "path": "src/tasks/run_video_retrieval.py",
    "content": "import json\nimport math\nimport os\nimport random\nimport time\nfrom collections import defaultdict\nfrom os.path import exists, join\n\nimport horovod.torch as hvd\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom apex import amp\nfrom easydict import EasyDict as edict\nfrom src.configs.config import shared_configs\nfrom src.datasets.data_utils import ImageNorm, mk_input_group\nfrom src.datasets.dataloader import InfiniteIterator, PrefetchLoader\nfrom src.datasets.dataset_video_retrieval import (\n    AlproVideoRetrievalDataset, AlproVideoRetrievalEvalDataset,\n    VideoRetrievalCollator)\nfrom src.modeling.alpro_models import AlproForVideoTextRetrieval\nfrom src.optimization.sched import get_lr_sched\nfrom src.optimization.utils import setup_e2e_optimizer\nfrom src.utils.basic_utils import (get_rounded_percentage, load_json,\n                                   load_jsonl, save_json)\nfrom src.utils.distributed import all_gather_list\nfrom src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer\nfrom src.utils.load_save import (ModelSaver,\n                                 load_state_dict_with_pos_embed_resizing,\n                                 save_training_meta)\nfrom src.utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file\nfrom src.utils.misc import NoOp, set_random_seed, zero_none_grad\nfrom torch.nn.utils import clip_grad_norm_\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data.distributed import DistributedSampler\nfrom tqdm import tqdm\nfrom transformers import BertConfig, BertTokenizerFast\n\n\ndef mk_video_ret_datalist(raw_datalist, cfg):\n    \"\"\"\n    Args:\n        raw_datalist: list(dict)\n        Each data point is {id: int, txt: str, vid_id: str}\n\n    Returns:\n\n    \"\"\"\n    LOGGER.info(f\"Loaded data size {len(raw_datalist)}\")\n    if cfg.data_ratio != 1.0:\n        random.shuffle(raw_datalist)\n        raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]\n        LOGGER.info(f\"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}\")\n\n    datalist = []\n    qid = 0\n    for raw_d in raw_datalist:\n        d = dict(\n            id=qid,\n            txt=raw_d[\"caption\"],\n            vid_id=raw_d[\"clip_name\"]\n        )\n        qid += 1\n        datalist.append(d)\n    LOGGER.info(f\"datalist {len(datalist)}\")\n    return datalist\n\n\ndef mk_video_ret_dataloader(anno_path, lmdb_dir, cfg, tokenizer, is_train=True):\n    \"\"\"\"\"\"\n    raw_datalist = load_jsonl(anno_path)\n    datalist = mk_video_ret_datalist(raw_datalist, cfg)\n    grouped = defaultdict(list)  # examples grouped by image/video id\n    for d in datalist:\n        grouped[d[\"vid_id\"]].append(d)\n    LOGGER.info(f\"grouped {len(grouped)}\")\n\n    # each group has a single image with multiple questions\n    group_datalist = mk_input_group(\n        grouped,\n        max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1,  # force 1 in eval,\n        is_train=is_train\n    )\n    LOGGER.info(f\"group_datalist {len(group_datalist)}\")\n\n    frm_sampling_strategy = cfg.frm_sampling_strategy\n    if not is_train and frm_sampling_strategy == \"rand\":\n        frm_sampling_strategy = \"uniform\"\n    \n    if 'msvd' in cfg.train_datasets[0]['name']:\n        video_fmt = '.avi'\n    else:\n        video_fmt = '.mp4'\n\n    dataset = AlproVideoRetrievalDataset(\n        datalist=group_datalist,\n        tokenizer=tokenizer,\n        img_lmdb_dir=lmdb_dir,\n        
max_img_size=cfg.crop_img_size,\n        max_txt_len=cfg.max_txt_len,\n        fps=cfg.fps,\n        num_frm=cfg.num_frm,\n        frm_sampling_strategy=frm_sampling_strategy,\n        itm_neg_size=0,\n        is_train=is_train,\n        img_db_type='rawvideo',\n        video_fmt=video_fmt\n    )\n    LOGGER.info(f\"is_train {is_train}, dataset size {len(dataset)} groups, \"\n                f\"each group {cfg.max_n_example_per_group if is_train else 1}\")\n    if cfg.do_inference:\n        batch_size = cfg.inference_batch_size\n    else:\n        batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size\n    sampler = DistributedSampler(\n        dataset, num_replicas=hvd.size(), rank=hvd.rank(),\n        shuffle=is_train)\n    vqa_collator = VideoRetrievalCollator(\n        tokenizer=tokenizer, max_length=cfg.max_txt_len)\n    dataloader = DataLoader(dataset,\n                            batch_size=batch_size,\n                            shuffle=False,\n                            sampler=sampler,\n                            num_workers=cfg.n_workers,\n                            pin_memory=cfg.pin_mem,\n                            collate_fn=vqa_collator.collate_batch)\n    return dataloader\n\n\ndef mk_video_ret_eval_dataloader(anno_path, lmdb_dir, cfg, tokenizer):\n    \"\"\"\n    eval_retrieval: bool, will sample one video per batch paired with multiple text.\n    Returns:\n\n    \"\"\"\n    raw_datalist = load_jsonl(anno_path)\n    datalist = mk_video_ret_datalist(raw_datalist, cfg)\n    frm_sampling_strategy = cfg.frm_sampling_strategy\n    if frm_sampling_strategy == \"rand\":\n        frm_sampling_strategy = \"uniform\"\n\n    if 'msvd' in cfg.train_datasets[0]['name']:\n        video_fmt = '.avi'\n    else:\n        video_fmt = '.mp4'\n\n    dataset = AlproVideoRetrievalEvalDataset(\n        datalist=datalist,\n        tokenizer=tokenizer,\n        img_lmdb_dir=lmdb_dir,\n        max_img_size=cfg.crop_img_size,\n        max_txt_len=cfg.max_txt_len,\n        fps=cfg.fps,\n        num_frm=cfg.num_frm,\n        frm_sampling_strategy=frm_sampling_strategy,\n        video_fmt=video_fmt,\n        img_db_type='rawvideo'\n    )\n    sampler = DistributedSampler(\n        dataset, num_replicas=hvd.size(), rank=hvd.rank(),\n        shuffle=False)\n    retrieval_collator = VideoRetrievalCollator(\n        tokenizer=tokenizer, max_length=cfg.max_txt_len)\n    dataloader = DataLoader(dataset,\n                            batch_size=1,  # already batched in dataset\n                            shuffle=False,\n                            sampler=sampler,\n                            num_workers=cfg.n_workers,\n                            pin_memory=cfg.pin_mem,\n                            collate_fn=retrieval_collator.collate_batch)\n    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)\n    dataloader = PrefetchLoader(dataloader, img_norm)\n    return dataloader\n\n\ndef setup_dataloaders(cfg, tokenizer):\n    LOGGER.info(\"Init. 
train_loader and val_loader...\")\n    train_loader = mk_video_ret_dataloader(\n        anno_path=cfg.train_datasets[0].txt,\n        lmdb_dir=cfg.train_datasets[0].img,\n        cfg=cfg, tokenizer=tokenizer, is_train=True\n    )\n    val_loader = mk_video_ret_dataloader(\n        anno_path=cfg.val_datasets[0].txt,\n        lmdb_dir=cfg.val_datasets[0].img,\n        cfg=cfg, tokenizer=tokenizer, is_train=False\n    )\n    img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)\n    train_loader = PrefetchLoader(train_loader, img_norm)\n    val_loader = PrefetchLoader(val_loader, img_norm)\n    return train_loader, val_loader\n\n\ndef setup_model(cfg, device=None):\n    LOGGER.info(\"Setup model...\")\n    # has to be a BertConfig instance\n    model_cfg = load_json(cfg.model_config)\n    model_cfg = BertConfig(**model_cfg)\n    # add downstream model config\n    add_attr_list = []\n    for k in add_attr_list:\n        setattr(model_cfg, k, cfg[k])\n\n    # we separate the CNN and the transformer in order to use different optimizer for each\n    # transformer still has a CNN layer inside, used to down sample grid.\n    LOGGER.info(\"setup e2e model\")\n\n    video_enc_cfg = load_json(cfg.visual_model_cfg)\n\n    video_enc_cfg['num_frm'] = cfg.num_frm\n    video_enc_cfg['img_size'] = cfg.crop_img_size\n\n    model = AlproForVideoTextRetrieval(\n        model_cfg, \n        input_format=cfg.img_input_format,\n        video_enc_cfg=video_enc_cfg\n        )\n    if cfg.e2e_weights_path:\n        LOGGER.info(f\"Loading e2e weights from {cfg.e2e_weights_path}\")\n        num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2\n        # NOTE strict if False if loaded from ALBEF ckpt\n        load_state_dict_with_pos_embed_resizing(model, \n                                                cfg.e2e_weights_path, \n                                                num_patches=num_patches, \n                                                num_frames=cfg.num_frm, \n                                                strict=False,\n                                                )\n    else:\n        LOGGER.info(f\"Loading visual weights from {cfg.visual_weights_path}\")\n        LOGGER.info(f\"Loading bert weights from {cfg.bert_weights_path}\")\n        model.load_separate_ckpt(\n            visual_weights_path=cfg.visual_weights_path,\n            bert_weights_path=cfg.bert_weights_path\n        )\n\n    # if cfg.freeze_cnn:\n    #     model.freeze_cnn_backbone()\n    model.to(device)\n\n    LOGGER.info(\"Setup model done!\")\n    return model\n\n\ndef forward_step(model, batch):\n    \"\"\"shared for training and validation\"\"\"\n    outputs = model(batch)  # dict\n    return outputs\n\ndef forward_inference_step(model, batch):\n    outputs = model.forward_inference(batch)\n    return outputs\n\n@torch.no_grad()\ndef validate(model, val_loader, eval_loader, cfg, train_global_step, eval_filepath):\n    \"\"\"use eval_score=False when doing inference on test sets where answers are not available\"\"\"\n    model.eval()\n\n    loss = 0.\n    n_ex = 0\n    n_corrects = 0\n    st = time.time()\n    debug_step = 10\n    for val_step, batch in enumerate(val_loader):\n        # forward pass\n        del batch[\"caption_ids\"]\n        outputs = forward_step(model, batch)\n        targets = batch['labels']\n\n        batch_loss = outputs['itm_loss'] + outputs['itc_loss']\n\n        if isinstance(batch_loss, torch.Tensor):\n            loss += batch_loss.sum().item()\n        else:\n       
     raise NotImplementedError('Expecting loss as Tensor, found: {}'.format(type(loss)))\n\n        # n_ex += len(targets)\n        n_ex += len(targets)\n\n        if cfg.debug and val_step >= debug_step:\n            break\n\n    loss = sum(all_gather_list(loss))\n    n_ex = sum(all_gather_list(n_ex))\n    n_corrects = sum(all_gather_list(n_corrects))\n\n    _, retrieval_metrics = inference_retrieval(model, eval_loader, eval_filepath, cfg)\n\n    model.train()\n\n    if hvd.rank() == 0:\n        # average loss for each example\n        acc = float(n_corrects / n_ex)\n        val_log = {'valid/loss': float(loss / n_ex), 'valid/acc': acc}\n        for ret_type, ret_m in retrieval_metrics.items():\n            val_log.update({f\"valid/{ret_type}_{k}\": round(v, 4) for k, v in ret_m.items()})\n\n        TB_LOGGER.log_scalar_dict(val_log)\n        LOGGER.info(f\"validation finished in {int(time.time() - st)} seconds.\"\n                    f\"itm_acc: {acc}. Retrieval res {retrieval_metrics}\")\n\n\ndef start_training(cfg):\n    set_random_seed(cfg.seed)\n\n    n_gpu = hvd.size()\n    cfg.n_gpu = n_gpu\n    device = torch.device(\"cuda\", hvd.local_rank())\n    torch.cuda.set_device(hvd.local_rank())\n    if hvd.rank() != 0:\n        LOGGER.disabled = True\n    LOGGER.info(\"device: {} n_gpu: {}, rank: {}, \"\n                \"16-bits training: {}\".format(\n                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))\n\n    model = setup_model(cfg, device=device)\n    model.train()\n    optimizer = setup_e2e_optimizer(model, cfg)\n\n    # Horovod: (optional) compression algorithm.compressin\n    compression = hvd.Compression.none\n    optimizer = hvd.DistributedOptimizer(\n        optimizer, named_parameters=model.named_parameters(),\n        compression=compression)\n\n    #  Horovod: broadcast parameters & optimizer state.\n    hvd.broadcast_parameters(model.state_dict(), root_rank=0)\n    hvd.broadcast_optimizer_state(optimizer, root_rank=0)\n\n    model, optimizer = amp.initialize(\n        model, optimizer, enabled=cfg.fp16, opt_level='O2',\n        keep_batchnorm_fp32=True)\n\n    # prepare data\n    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)\n    train_loader, val_loader = setup_dataloaders(cfg, tokenizer)\n    eval_loader = mk_video_ret_eval_dataloader(\n        anno_path=cfg.val_datasets[0].txt,\n        lmdb_dir=cfg.val_datasets[0].img,\n        cfg=cfg, tokenizer=tokenizer,\n    )\n\n    # compute the number of steps and update cfg\n    total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group\n    total_train_batch_size = int(\n        n_gpu * cfg.train_batch_size *\n        cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)\n    cfg.num_train_steps = int(math.ceil(\n        1. * cfg.num_train_epochs * total_n_examples / total_train_batch_size))\n\n    cfg.valid_steps = int(math.ceil(\n        1. * cfg.num_train_steps / cfg.num_valid /\n        cfg.min_valid_steps)) * cfg.min_valid_steps\n    actual_num_valid = int(math.floor(\n        1. 
* cfg.num_train_steps / cfg.valid_steps)) + 1\n\n    # restore\n    restorer = TrainingRestorer(cfg, model, optimizer)\n    global_step = restorer.global_step\n    TB_LOGGER.global_step = global_step\n    if hvd.rank() == 0:\n        LOGGER.info(\"Saving training meta...\")\n        save_training_meta(cfg)\n        LOGGER.info(\"Saving training done...\")\n        TB_LOGGER.create(join(cfg.output_dir, 'log'))\n        pbar = tqdm(total=cfg.num_train_steps)\n        model_saver = ModelSaver(join(cfg.output_dir, \"ckpt\"))\n        add_log_to_file(join(cfg.output_dir, \"log\", \"log.txt\"))\n    else:\n        LOGGER.disabled = True\n        pbar = NoOp()\n        model_saver = NoOp()\n        restorer = NoOp()\n\n    if global_step > 0:\n        pbar.update(global_step)\n\n    LOGGER.info(cfg)\n    LOGGER.info(\"Starting training...\")\n    LOGGER.info(f\"***** Running training with {n_gpu} GPUs *****\")\n    LOGGER.info(f\"  Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}\")\n    LOGGER.info(f\"  max_n_example_per_group = {cfg.max_n_example_per_group}\")\n    LOGGER.info(f\"  Accumulate steps = {cfg.gradient_accumulation_steps}\")\n    LOGGER.info(f\"  Total batch size = #GPUs * Single-GPU batch size * \"\n                f\"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}\")\n    LOGGER.info(f\"  Total #epochs = {cfg.num_train_epochs}\")\n    LOGGER.info(f\"  Total #steps = {cfg.num_train_steps}\")\n    LOGGER.info(f\"  Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times\")\n\n    LOGGER.info(f'Step {global_step}: start validation')\n    validate(\n        model, val_loader, eval_loader, cfg, global_step,\n        eval_filepath=cfg.val_datasets[0].txt)\n\n    # quick hack for amp delay_unscale bug\n    with optimizer.skip_synchronize():\n        optimizer.zero_grad()\n        if global_step == 0:\n            optimizer.step()\n    debug_step = 3\n    running_loss = RunningMeter('train_loss')\n\n    for step, batch in enumerate(InfiniteIterator(train_loader)):\n        # forward pass\n        del batch[\"caption_ids\"]\n        mini_batch = dict()\n        for k, v in batch.items():\n            if k != \"visual_inputs\":\n                mini_batch[k] = v\n\n        pool_method = cfg.score_agg_func\n        # could be 1, where only a single clip is used\n        num_clips = cfg.train_n_clips\n\n        assert num_clips == 1, \"Support only single clip for now.\"\n\n        num_frm = cfg.num_frm\n        # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)\n        bsz = batch[\"visual_inputs\"].shape[0]\n        new_visual_shape = (bsz, num_clips, num_frm) + batch[\"visual_inputs\"].shape[2:]\n        visual_inputs = batch[\"visual_inputs\"].view(*new_visual_shape)\n        model_out = []\n\n        for clip_idx in range(num_clips):\n            # (B, num_frm, C, H, W)\n            mini_batch[\"visual_inputs\"] = visual_inputs[:, clip_idx]\n            mini_batch[\"n_examples_list\"] = batch[\"n_examples_list\"]\n            # outputs = forward_step(model, mini_batch, cfg)\n            outputs = forward_step(model, mini_batch)\n            model_out.append(outputs)\n            # the losses are cross entropy and mse, no need to * num_labels\n\n        loss_itm = outputs['itm_loss']\n        loss_itc = outputs['itc_loss']\n        loss = loss_itm + loss_itc\n\n        running_loss(loss.item())\n        # backward pass\n        delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0\n        with 
amp.scale_loss(\n                loss, optimizer, delay_unscale=delay_unscale\n                ) as scaled_loss:\n            scaled_loss.backward()\n            zero_none_grad(model)\n            optimizer.synchronize()\n\n        # optimizer\n        if (step + 1) % cfg.gradient_accumulation_steps == 0:\n            global_step += 1\n\n            # learning rate scheduling\n            n_epoch = int(1. * total_train_batch_size * global_step\n                          / total_n_examples)\n\n            # learning rate scheduling cnn\n            lr_this_step = get_lr_sched(\n                global_step, cfg.decay, cfg.learning_rate,\n                cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,\n                decay_epochs=cfg.step_decay_epochs,\n                multi_step_epoch=n_epoch)\n\n            # Hardcoded param group length\n            for pg_n, param_group in enumerate(\n                    optimizer.param_groups):\n                    param_group['lr'] = lr_this_step\n\n            if step % cfg.log_interval == 0:\n                TB_LOGGER.add_scalar(\n                    \"train/lr\", lr_this_step, global_step)\n\n            TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)\n\n            # update model params\n            if cfg.grad_norm != -1:\n                grad_norm = clip_grad_norm_(\n                    amp.master_params(optimizer),\n                    cfg.grad_norm)\n                TB_LOGGER.add_scalar(\n                    \"train/grad_norm\", grad_norm, global_step)\n            TB_LOGGER.step()\n\n            # Check if there is None grad\n            none_grads = [\n                p[0] for p in model.named_parameters()\n                if p[1].requires_grad and p[1].grad is None]\n\n            assert len(none_grads) == 0, f\"{none_grads}\"\n\n            with optimizer.skip_synchronize():\n                optimizer.step()\n                optimizer.zero_grad()\n            restorer.step()\n            pbar.update(1)\n\n            # checkpoint\n            if global_step % cfg.valid_steps == 0:\n                LOGGER.info(f'Step {global_step}: start validation')\n                validate(\n                    model, val_loader, eval_loader, cfg, global_step,\n                    eval_filepath=cfg.val_datasets[0].txt)\n                model_saver.save(step=global_step, model=model)\n        if global_step >= cfg.num_train_steps:\n            break\n\n        if cfg.debug and global_step >= debug_step:\n            break\n\n    if global_step % cfg.valid_steps != 0:\n        LOGGER.info(f'Step {global_step}: start validation')\n        validate(\n            model, val_loader, eval_loader, cfg, global_step,\n            eval_filepath=cfg.val_datasets[0].txt)\n        model_saver.save(step=global_step, model=model)\n\n\ndef get_retrieval_metric_from_bool_matrix(bool_matrix):\n    \"\"\" Calc Recall@K, median rank and mean rank.\n    Args:\n        bool_matrix: np array of shape (#txt, #vid), np.bool,\n            sorted row-wise from most similar to less similar.\n            The GT position is marked as 1, while all the others are 0,\n            each row will only have one 1.\n\n    Returns:\n        retrieval_metrics: dict(\n            R1=.., R5=..., R10=..., MedR=..., MeanR=...\n        )\n    \"\"\"\n    num_row = bool_matrix.shape[0]  # #rows\n    row_range, gt_ranks = np.where(bool_matrix == 1)\n    assert np.allclose(row_range, np.arange(len(row_range))), \\\n        \"each row should only a single GT\"\n    retrieval_metrics 
= dict(\n        r1=100 * bool_matrix[:, 0].sum() / num_row,\n        r5=100 * bool_matrix[:, :5].sum() / num_row,\n        r10=100 * bool_matrix[:, :10].sum() / num_row,\n        medianR=np.median(gt_ranks+1),  # convert to 1-indexed system instead of 0-indexed.\n        meanR=np.mean(gt_ranks+1)\n    )\n    return retrieval_metrics\n\n\ndef get_retrieval_scores(score_matrix, gt_row2col_id_mapping, row_idx2id, col_id2idx):\n    # rank scores\n    score_matrix_sorted, indices_sorted = \\\n        torch.sort(score_matrix, dim=1, descending=True)  # (#txt, #vid)\n\n    # build bool matrix, where the GT position is marked as 1, all the others are 0,\n    num_row = len(score_matrix)\n    gt_col_indices = torch.zeros(num_row, 1)\n    for idx in range(num_row):\n        gt_col_id = gt_row2col_id_mapping[row_idx2id[idx]]\n        gt_col_indices[idx, 0] = col_id2idx[gt_col_id]\n\n    bool_matrix = indices_sorted == gt_col_indices  # (#txt, #vid)\n    retrieval_metrics = get_retrieval_metric_from_bool_matrix(bool_matrix.numpy())\n    return retrieval_metrics\n\n\ndef eval_retrieval(vid_txt_score_dicts, gt_txt_id2vid_id, id2data):\n    \"\"\"\n    Args:\n        vid_txt_score_dicts: list(dict), each dict is dict(vid_id=..., txt_id=..., score=...)\n        gt_txt_id2vid_id: dict, ground-truth {txt_id: vid_id}\n        id2data: dict, {txt_id: single_example_data}\n\n    Returns:\n\n    \"\"\"\n    # group prediction by txt_id\n    scores_group_by_txt_ids = defaultdict(list)\n    for d in vid_txt_score_dicts:\n        scores_group_by_txt_ids[d[\"txt_id\"]].append(d)\n\n    # clean duplicated videos\n    _scores_group_by_txt_ids = defaultdict(list)\n    for txt_id, txt_vid_pairs in scores_group_by_txt_ids.items():\n        added_vid_ids = []\n        for d in txt_vid_pairs:\n            if d[\"vid_id\"] not in added_vid_ids:\n                _scores_group_by_txt_ids[txt_id].append(d)\n                added_vid_ids.append(d[\"vid_id\"])\n    scores_group_by_txt_ids = _scores_group_by_txt_ids\n\n    num_txt = len(scores_group_by_txt_ids)\n    any_key = list(scores_group_by_txt_ids.keys())[0]\n    vid_ids = [d[\"vid_id\"] for d in scores_group_by_txt_ids[any_key]]\n    num_vid = len(vid_ids)\n    assert len(set(vid_ids)) == num_vid, \"Each caption should be compared to each video only once.\"\n    for k, v in scores_group_by_txt_ids.items():\n        assert num_vid == len(v), \"each captions should be compared with the same #videos.\"\n\n    # row/col indices in the score matrix\n    # *_id are the original ids, *_idx are the matrix indices\n    txt_id2idx = {txt_id: idx for idx, txt_id in enumerate(scores_group_by_txt_ids)}\n    vid_id2idx = {vid_id: idx for idx, vid_id in enumerate(vid_ids)}\n    txt_idx2id = {v: k for k, v in txt_id2idx.items()}\n    vid_idx2id = {v: k for k, v in vid_id2idx.items()}\n\n    # build score (float32) and vid_id (str) matrix\n    score_matrix = torch.zeros(num_txt, num_vid)\n    sim_matrix = torch.zeros(num_txt, num_vid)\n    for txt_id, preds in scores_group_by_txt_ids.items():\n        txt_idx = txt_id2idx[txt_id]\n        for p in preds:\n            vid_idx = vid_id2idx[p[\"vid_id\"]]\n            score_matrix[txt_idx, vid_idx] = p[\"score\"]\n            sim_matrix[txt_idx, vid_idx] = p['sim']\n\n    # [dxli] discard pairs with low ITC similarity scores\n    # top_k, indices = torch.topk(sim_matrix, dim=1, k=100)\n    # new_sim_matrix = torch.zeros_like(sim_matrix)\n    # new_sim_matrix = new_sim_matrix.scatter(1, indices, top_k)\n    # score_matrix[new_sim_matrix == 
0] = 0\n\n    # text to video retrieval, score_matrix--> (#txt, #vid)\n    # given a text, retrieve most relevant videos\n    t2v_retrieval_metrics = get_retrieval_scores(\n        score_matrix, gt_txt_id2vid_id, txt_idx2id, vid_id2idx)\n    # video to text retrieval, score_matrix--> (#vid, #txt)\n    # given a video, retrieve most relevant videos\n    score_matrix = score_matrix.transpose(0, 1)\n    gt_vid_id2txt_id = {v: k for k, v in gt_txt_id2vid_id.items()}\n    v2t_retrieval_metrics = get_retrieval_scores(\n        score_matrix, gt_vid_id2txt_id, vid_idx2id, txt_id2idx)\n    retrieval_metrics = dict(\n        text2video=t2v_retrieval_metrics,\n        video2text=v2t_retrieval_metrics\n    )\n    return retrieval_metrics\n\n\n@torch.no_grad()\ndef inference_retrieval(model, val_loader, eval_file_path, cfg):\n    model.eval()\n    retrieval_res = []  # list(dict): dict(vid_id=..., txt_id=..., score=...)\n    st = time.time()\n    eval_bsz = cfg.inference_batch_size if cfg.do_inference else cfg.eval_retrieval_batch_size\n    LOGGER.info(f\"Evaluate retrieval #video per GPU: {len(val_loader)}\")\n    if hvd.rank() == 0:\n        pbar = tqdm(total=len(val_loader), desc=\"eval\")\n\n    for batch in val_loader:\n        # each batch contains 1 video and N (=1000) captions\n        n_mini_batches = math.ceil(len(batch[\"caption_ids\"]) / eval_bsz)\n        vid_id = batch[\"vid_id\"]\n        for idx in range(n_mini_batches):\n            # compile shared text part\n            mini_batch = dict()\n            for k in [\"text_input_ids\", \"text_input_mask\", \"labels\"]:\n                if batch[k] is not None:\n                    mini_batch[k] = batch[k][idx * eval_bsz:(idx + 1) * eval_bsz]\n                else:\n                    mini_batch[k] = None\n            caption_ids = batch[\"caption_ids\"][idx * eval_bsz:(idx + 1) * eval_bsz]\n            # bsz = len(caption_ids)\n            mini_batch[\"n_examples_list\"] = [len(caption_ids)]\n\n            num_clips = cfg.inference_n_clips\n            num_frm = cfg.num_frm\n            # (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)\n            new_visual_shape = (1, num_clips, num_frm) + batch[\"visual_inputs\"].shape[2:]\n            visual_inputs = batch[\"visual_inputs\"].view(*new_visual_shape)\n            logits = []\n            sim_scores = []\n            for clip_idx in range(num_clips):\n                mini_batch[\"visual_inputs\"] = visual_inputs[:, clip_idx]\n                if cfg.fp16:\n                    # FIXME not sure why we need to do this explicitly?\n                    mini_batch[\"visual_inputs\"] = mini_batch[\"visual_inputs\"].half()\n                outputs = forward_inference_step(model, mini_batch)\n                logits.append(outputs[\"logits\"].cpu())\n                sim_scores.append(outputs[\"itc_scores\"].cpu())\n\n            logits = torch.stack(logits)  # (num_frm, B, 1 or 2)\n            sim_scores = torch.stack(sim_scores)\n            \n            # FIXME not sure why need to convert dtype explicitly\n            logits = logits.squeeze().float()\n            sim_scores = sim_scores.squeeze().float().tolist()\n            if logits.shape[1] == 2:\n                # [dxli] uses 1 for positive and 0 for negative.\n                # therefore we choose dim=1\n                probs = F.softmax(logits, dim=1)[:, 1].tolist()\n            else:\n                raise NotImplementedError('Not supported (unclear purposes)!')\n            for cap_id, score, sim in 
zip(caption_ids, probs, sim_scores):\n                retrieval_res.append(dict(\n                    vid_id=vid_id,\n                    txt_id=cap_id,\n                    score=round(score, 4),\n                    sim=round(sim, 4)\n                ))\n\n        if hvd.rank() == 0:\n            pbar.update(1)\n\n    # ###### Saving with Horovod ####################\n    # dummy sync\n    _ = None\n    all_gather_list(_)\n    n_gpu = hvd.size()\n    eval_dir = join(cfg.output_dir, f\"results_{os.path.splitext(os.path.basename(eval_file_path))[0]}\")\n    os.makedirs(eval_dir, exist_ok=True)\n    if n_gpu > 1:\n        # with retrial, as azure blob fails occasionally.\n        max_save_load_trial = 10\n        save_trial = 0\n        while save_trial < max_save_load_trial:\n            try:\n                LOGGER.info(f\"Save results trial NO. {save_trial}\")\n                save_json(\n                    retrieval_res,\n                    join(eval_dir, f\"tmp_results_rank{hvd.rank()}.json\"))\n                break\n            except Exception as e:\n                print(f\"Saving exception: {e}\")\n                save_trial += 1\n\n    # dummy sync\n    _ = None\n    all_gather_list(_)\n    # join results\n    if n_gpu > 1 and hvd.rank() == 0:\n        retrieval_res = []\n        for rk in range(n_gpu):\n            retrieval_res.extend(load_json(\n                join(eval_dir, f\"tmp_results_rank{rk}.json\")))\n        LOGGER.info('results joined')\n\n    if hvd.rank() == 0:\n        retrieval_metrics = eval_retrieval(\n            retrieval_res, val_loader.dataset.gt_cap_id2vid_id, val_loader.dataset.id2data)\n        LOGGER.info(f\"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_metrics}\")\n    else:\n        retrieval_metrics = None\n\n    model.train()\n    return retrieval_res, retrieval_metrics\n\n\ndef start_inference(cfg):\n    set_random_seed(cfg.seed)\n    n_gpu = hvd.size()\n    device = torch.device(\"cuda\", hvd.local_rank())\n    torch.cuda.set_device(hvd.local_rank())\n    if hvd.rank() != 0:\n        LOGGER.disabled = True\n\n    inference_res_dir = join(\n        cfg.output_dir,\n        f\"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/\"\n        f\"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}\"\n    )\n\n    if hvd.rank() == 0:\n        os.makedirs(inference_res_dir, exist_ok=True)\n        save_json(cfg, join(inference_res_dir, \"raw_args.json\"),\n                  save_pretty=True)\n\n    LOGGER.info(\"device: {} n_gpu: {}, rank: {}, \"\n                \"16-bits training: {}\".format(\n                    device, n_gpu, hvd.rank(), bool(cfg.fp16)))\n\n    # overwrite cfg with stored_cfg,\n    # but skip keys containing the keyword 'inference'\n    stored_cfg_path = join(cfg.output_dir, \"log/args.json\")\n    stored_cfg = edict(load_json(stored_cfg_path))\n    for k, v in cfg.items():\n        if k in stored_cfg and \"inference\" not in k and \"output_dir\" not in k:\n            setattr(cfg, k, stored_cfg[k])\n\n    # setup models\n    cfg.model_config = join(cfg.output_dir, \"log/model_config.json\")\n    e2e_weights_path = join(\n        cfg.output_dir, f\"ckpt/model_step_{cfg.inference_model_step}.pt\")\n    if exists(e2e_weights_path):\n        cfg.e2e_weights_path = e2e_weights_path\n    else:\n        raise NotImplementedError(\"Not supporting loading separate weights for inference.\")\n    model = setup_model(cfg, device=device)\n    model.eval()\n\n    # FIXME 
separate scaling for each loss\n    model = amp.initialize(\n        model, enabled=cfg.fp16, opt_level='O2')\n\n    global_step = 0\n    # prepare data\n    tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)\n    cfg.data_ratio = 1.\n\n    val_loader = mk_video_ret_eval_dataloader(\n        anno_path=cfg.inference_txt_db,\n        lmdb_dir=cfg.inference_img_db,\n        cfg=cfg, tokenizer=tokenizer,\n    )\n\n    LOGGER.info(cfg)\n    LOGGER.info(\"Starting inference...\")\n    LOGGER.info(f\"***** Running inference with {n_gpu} GPUs *****\")\n    LOGGER.info(f\"  Batch size = {cfg.inference_batch_size}\")\n\n    LOGGER.info(f'Step {global_step}: start validation')\n    ret_results, ret_scores = inference_retrieval(\n        model, val_loader, cfg.inference_txt_db, cfg)\n\n    if hvd.rank() == 0:\n        save_json(cfg, join(inference_res_dir, \"merged_args.json\"),\n                  save_pretty=True)\n        save_json(ret_results, join(inference_res_dir, \"results.json\"),\n                  save_pretty=True)\n        save_json(ret_scores, join(inference_res_dir, \"scores.json\"),\n                  save_pretty=True)\n\n\nif __name__ == '__main__':\n    # Initialize Horovod\n    hvd.init()\n    input_cfg = shared_configs.get_video_retrieval_args()\n    if input_cfg.do_inference:\n        start_inference(input_cfg)\n    else:\n        start_training(input_cfg)\n"
  },
  {
    "path": "src/utils/basic_utils.py",
    "content": "import os\nimport ujson as json\nimport zipfile\nimport numpy as np\nimport pickle\n\nimport pandas as pd\n\n\ndef load_pickle(filename):\n    with open(filename, \"rb\") as f:\n        return pickle.load(f)\n\n\ndef save_pickle(data, filename):\n    with open(filename, \"wb\") as f:\n        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)\n\n\ndef load_json(filename):\n    with open(filename, \"r\") as f:\n        return json.load(f)\n\n\ndef save_json(data, filename, save_pretty=False, sort_keys=False):\n    with open(filename, \"w\") as f:\n        if save_pretty:\n            f.write(json.dumps(data, indent=4, sort_keys=sort_keys))\n        else:\n            json.dump(data, f)\n\n\ndef load_jsonl(filename):\n    with open(filename, \"r\") as f:\n        return [json.loads(l.strip(\"\\n\")) for l in f.readlines()]\n\n\ndef save_jsonl(data, filename):\n    \"\"\"data is a list\"\"\"\n    with open(filename, \"w\") as f:\n        f.write(\"\\n\".join([json.dumps(e) for e in data]))\n\n\ndef concat_json_list(filepaths, save_path):\n    json_lists = []\n    for p in filepaths:\n        json_lists += load_json(p)\n    save_json(json_lists, save_path)\n\n\ndef save_lines(list_of_str, filepath):\n    with open(filepath, \"w\") as f:\n        f.write(\"\\n\".join(list_of_str))\n\n\ndef read_lines(filepath):\n    with open(filepath, \"r\") as f:\n        return [e.strip(\"\\n\") for e in f.readlines()]\n\n\ndef mkdirp(p):\n    if not os.path.exists(p):\n        os.makedirs(p)\n\n\ndef flat_list_of_lists(l):\n    \"\"\"flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]\"\"\"\n    return [item for sublist in l for item in sublist]\n\n\ndef convert_to_seconds(hms_time):\n    \"\"\" convert '00:01:12' to 72 seconds.\n    :hms_time (str): time in comma separated string, e.g. '00:01:12'\n    :return (int): time in seconds, e.g. 
72\n    \"\"\"\n    times = [float(t) for t in hms_time.split(\":\")]\n    return times[0] * 3600 + times[1] * 60 + times[2]\n\n\ndef get_video_name_from_url(url):\n    return url.split(\"/\")[-1][:-4]\n\n\ndef merge_dicts(list_dicts):\n    merged_dict = list_dicts[0].copy()\n    for i in range(1, len(list_dicts)):\n        merged_dict.update(list_dicts[i])\n    return merged_dict\n\n\ndef l2_normalize_np_array(np_array, eps=1e-5):\n    \"\"\"np_array: np.ndarray, (*, D), where the last dim will be normalized\"\"\"\n    return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps)\n\n\ndef make_zipfile(src_dir, save_path, enclosing_dir=\"\", exclude_dirs=None, exclude_extensions=None,\n                 exclude_dirs_substring=None):\n    \"\"\"make a zip file of root_dir, save it to save_path.\n    exclude_paths will be excluded if it is a subdir of root_dir.\n    An enclosing_dir is added if specified.\n    \"\"\"\n    abs_src = os.path.abspath(src_dir)\n    with zipfile.ZipFile(save_path, \"w\") as zf:\n        for dirname, subdirs, files in os.walk(src_dir):\n            if exclude_dirs is not None:\n                for e_p in exclude_dirs:\n                    if e_p in subdirs:\n                        subdirs.remove(e_p)\n            if exclude_dirs_substring is not None:\n                to_rm = []\n                for d in subdirs:\n                    if exclude_dirs_substring in d:\n                        to_rm.append(d)\n                for e in to_rm:\n                    subdirs.remove(e)\n            arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:])\n            zf.write(dirname, arcname)\n            for filename in files:\n                if exclude_extensions is not None:\n                    if os.path.splitext(filename)[1] in exclude_extensions:\n                        continue  # do not zip it\n                absname = os.path.join(dirname, filename)\n                arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:])\n                zf.write(absname, arcname)\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current/max/min value\"\"\"\n    def __init__(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n        self.max = -1e10\n        self.min = 1e10\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n        self.max = -1e10\n        self.min = 1e10\n\n    def update(self, val, n=1):\n        self.max = max(val, self.max)\n        self.min = min(val, self.min)\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):\n    \"\"\"Dissect an array (N, D) into a list of sub-arrays,\n    np_array.shape[0] == sum(lengths), Output is a list of nd arrays, singleton dimension is kept\"\"\"\n    if assert_equal:\n        assert len(np_array) == sum(lengths)\n    length_indices = [0, ]\n    for i in range(len(lengths)):\n        length_indices.append(length_indices[i] + lengths[i])\n    if dim == 0:\n        array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))]\n    elif dim == 1:\n        array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]\n    elif dim == 2:\n        array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(
len(lengths)))]\n    else:\n        raise NotImplementedError\n    return array_list\n\n\ndef get_ratio_from_counter(counter_obj, threshold=200):\n    keys = counter_obj.keys()\n    values = counter_obj.values()\n    filtered_values = [counter_obj[k] for k in keys if k > threshold]\n    return float(sum(filtered_values)) / sum(values)\n\n\ndef get_rounded_percentage(float_number, n_floats=2):\n    return round(float_number * 100, n_floats)\n\n\ndef read_dataframe(pkl_path):\n    return pd.read_pickle(pkl_path)\n\n\ndef save_frames_grid(img_array, out_path):\n    import torch\n    from torchvision.utils import make_grid\n    from PIL import Image\n\n    if len(img_array.shape) == 3:\n        img_array = img_array.unsqueeze(0)\n    elif len(img_array.shape) == 5:\n        b, t, c, h, w = img_array.shape\n        img_array = img_array.view(-1, c, h, w)\n    elif len(img_array.shape) == 4:\n        pass\n    else:\n        raise NotImplementedError('Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored.')\n    \n    assert img_array.shape[1] == 3, \"Expecting input shape of (3, H, W), i.e. RGB-only.\"\n    \n    grid = make_grid(img_array)\n    ndarr = grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()\n\n    img = Image.fromarray(ndarr)\n\n    img.save(out_path)\n    "
  },
  {
    "path": "src/utils/distributed.py",
    "content": "\"\"\"\r\nCopyright (c) Microsoft Corporation.\r\nLicensed under the MIT license.\r\ndistributed API using Horovod\r\nModified from OpenNMT's native pytorch distributed utils\r\n(https://github.com/OpenNMT/OpenNMT-py)\r\n\"\"\"\r\n\r\nimport math\r\nimport pickle\r\n\r\nimport torch\r\nfrom horovod import torch as hvd\r\nfrom horovod.torch.mpi_ops import rank, size\r\n\r\n\r\ndef all_reduce_and_rescale_tensors(tensors, rescale_denom):\r\n    \"\"\"All-reduce and rescale tensors at once (as a flattened tensor)\r\n    Args:\r\n        tensors: list of Tensors to all-reduce\r\n        rescale_denom: denominator for rescaling summed Tensors\r\n    \"\"\"\r\n    # buffer size in bytes, determine equiv. # of elements based on data type\r\n    sz = sum(t.numel() for t in tensors)\r\n    buffer_t = tensors[0].new(sz).zero_()\r\n\r\n    # copy tensors into buffer_t\r\n    offset = 0\r\n    for t in tensors:\r\n        numel = t.numel()\r\n        buffer_t[offset:offset+numel].copy_(t.view(-1))\r\n        offset += numel\r\n\r\n    # all-reduce and rescale\r\n    hvd.allreduce_(buffer_t[:offset])\r\n    buffer_t.div_(rescale_denom)\r\n\r\n    # copy all-reduced buffer back into tensors\r\n    offset = 0\r\n    for t in tensors:\r\n        numel = t.numel()\r\n        t.view(-1).copy_(buffer_t[offset:offset+numel])\r\n        offset += numel\r\n\r\n\r\ndef all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom,\r\n                                           buffer_size=10485760):\r\n    \"\"\"All-reduce and rescale tensors in chunks of the specified size.\r\n    Args:\r\n        tensors: list of Tensors to all-reduce\r\n        rescale_denom: denominator for rescaling summed Tensors\r\n        buffer_size: all-reduce chunk size in bytes\r\n    \"\"\"\r\n    # buffer size in bytes, determine equiv. 
# of elements based on data type\r\n    buffer_t = tensors[0].new(\r\n        math.ceil(buffer_size / tensors[0].element_size())).zero_()\r\n    buffer = []\r\n\r\n    def all_reduce_buffer():\r\n        # copy tensors into buffer_t\r\n        offset = 0\r\n        for t in buffer:\r\n            numel = t.numel()\r\n            buffer_t[offset:offset+numel].copy_(t.view(-1))\r\n            offset += numel\r\n\r\n        # all-reduce and rescale\r\n        hvd.allreduce_(buffer_t[:offset])\r\n        buffer_t.div_(rescale_denom)\r\n\r\n        # copy all-reduced buffer back into tensors\r\n        offset = 0\r\n        for t in buffer:\r\n            numel = t.numel()\r\n            t.view(-1).copy_(buffer_t[offset:offset+numel])\r\n            offset += numel\r\n\r\n    filled = 0\r\n    for t in tensors:\r\n        sz = t.numel() * t.element_size()\r\n        if sz > buffer_size:\r\n            # tensor is bigger than buffer, all-reduce and rescale directly\r\n            hvd.allreduce_(t)\r\n            t.div_(rescale_denom)\r\n        elif filled + sz > buffer_size:\r\n            # buffer is full, all-reduce and replace buffer with grad\r\n            all_reduce_buffer()\r\n            buffer = [t]\r\n            filled = sz\r\n        else:\r\n            # add tensor to buffer\r\n            buffer.append(t)\r\n            filled += sz\r\n\r\n    if len(buffer) > 0:\r\n        all_reduce_buffer()\r\n\r\n\r\ndef broadcast_tensors(tensors, root_rank, buffer_size=10485760):\r\n    \"\"\"broadcast tensors in chunks of the specified size.\r\n    Args:\r\n        tensors: list of Tensors to broadcast\r\n        root_rank: rank to broadcast\r\n        buffer_size: all-reduce chunk size in bytes\r\n    \"\"\"\r\n    # buffer size in bytes, determine equiv. 
# of elements based on data type\r\n    buffer_t = tensors[0].new(\r\n        math.ceil(buffer_size / tensors[0].element_size())).zero_()\r\n    buffer = []\r\n\r\n    def broadcast_buffer():\r\n        # copy tensors into buffer_t\r\n        offset = 0\r\n        for t in buffer:\r\n            numel = t.numel()\r\n            buffer_t[offset:offset+numel].copy_(t.view(-1))\r\n            offset += numel\r\n\r\n        # broadcast\r\n        hvd.broadcast_(buffer_t[:offset], root_rank)\r\n\r\n        # copy all-reduced buffer back into tensors\r\n        offset = 0\r\n        for t in buffer:\r\n            numel = t.numel()\r\n            t.view(-1).copy_(buffer_t[offset:offset+numel])\r\n            offset += numel\r\n\r\n    filled = 0\r\n    for t in tensors:\r\n        sz = t.numel() * t.element_size()\r\n        if sz > buffer_size:\r\n            # tensor is bigger than buffer, broadcast directly\r\n            hvd.broadcast_(t, root_rank)\r\n        elif filled + sz > buffer_size:\r\n            # buffer is full, broadcast and replace buffer with tensor\r\n            broadcast_buffer()\r\n            buffer = [t]\r\n            filled = sz\r\n        else:\r\n            # add tensor to buffer\r\n            buffer.append(t)\r\n            filled += sz\r\n\r\n    if len(buffer) > 0:\r\n        broadcast_buffer()\r\n\r\n\r\ndef all_gather_list(data, max_size=4096):\r\n    \"\"\"Gathers arbitrary data from all nodes into a list.\"\"\"\r\n    world_size = hvd.size()\r\n    # cache the staging buffer on the function object; compare its length (an int) to max_size\r\n    if not hasattr(all_gather_list, '_in_buffer') or \\\r\n            max_size != all_gather_list._in_buffer.size(0):\r\n        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)\r\n    in_buffer = all_gather_list._in_buffer\r\n\r\n    enc = pickle.dumps(data)\r\n    enc_size = len(enc)\r\n    if enc_size + 2 > max_size:\r\n        raise ValueError(\r\n            'encoded data exceeds max_size: {}'.format(enc_size + 2))\r\n    assert max_size < 255*256\r\n    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k\r\n    in_buffer[1] = enc_size % 255\r\n    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))\r\n\r\n    # FIXME cannot create buffer\r\n    out = hvd.allgather(in_buffer.cuda())\r\n\r\n    results = []\r\n    for i in range(0, max_size*world_size, max_size):\r\n        out_buffer = out[i:i+max_size]\r\n        size = (255 * out_buffer[0].item()) + out_buffer[1].item()\r\n\r\n        bytes_list = bytes(out_buffer[2:size+2].tolist())\r\n        result = pickle.loads(bytes_list)\r\n        results.append(result)\r\n    return results\r\n\r\n\r\ndef any_broadcast(data, root_rank, max_size=4096):\r\n    \"\"\"broadcast arbitrary data from root_rank to all nodes.\"\"\"\r\n    # cache the staging buffer on the function object, mirroring all_gather_list\r\n    if not hasattr(any_broadcast, '_buffer') or \\\r\n            max_size != any_broadcast._buffer.size(0):\r\n        any_broadcast._buffer = torch.cuda.ByteTensor(max_size)\r\n    buffer_ = any_broadcast._buffer\r\n\r\n    enc = pickle.dumps(data)\r\n    enc_size = len(enc)\r\n    if enc_size + 2 > max_size:\r\n        raise ValueError(\r\n            'encoded data exceeds max_size: {}'.format(enc_size + 2))\r\n    assert max_size < 255*256\r\n    buffer_[0] = enc_size // 255  # this encoding works for max_size < 65k\r\n    buffer_[1] = enc_size % 255\r\n    buffer_[2:enc_size+2] = torch.ByteTensor(list(enc))\r\n\r\n    hvd.broadcast_(buffer_, root_rank)\r\n\r\n    size = (255 * buffer_[0].item()) + buffer_[1].item()\r\n\r\n    bytes_list = bytes(buffer_[2:size+2].tolist())\r\n    result = 
pickle.loads(bytes_list)\r\n    return result\r\n\r\ndef allgather_object(obj, name=None):\r\n    \"\"\"\r\n    Serializes and allgathers an object from all other processes.\r\n\r\n    Arguments:\r\n        obj: An object capable of being serialized without losing any context.\r\n        name: Optional name to use during allgather, will default to the class\r\n              type.\r\n\r\n    Returns:\r\n        The list of objects that were allgathered across all ranks.\r\n    \"\"\"\r\n    import io\r\n    import cloudpickle\r\n\r\n    if name is None:\r\n        name = type(obj).__name__\r\n\r\n    def load(byte_array):\r\n        buf = io.BytesIO(byte_array.tobytes())\r\n        return cloudpickle.load(buf)\r\n\r\n    b = io.BytesIO()\r\n    cloudpickle.dump(obj, b)\r\n\r\n    t = torch.ByteTensor(bytearray(b.getvalue()))\r\n    sz = torch.IntTensor([t.shape[0]])\r\n\r\n    sizes = hvd.allgather(sz, name=name + '.sz').numpy()\r\n    gathered = hvd.allgather(t, name=name + '.t').numpy()\r\n\r\n    def select(i):\r\n        start = sum(sizes[:i])\r\n        end = start + sizes[i]\r\n        return gathered[start:end]\r\n\r\n    return [load(select(i)) for i in range(size())]"
  },
  {
    "path": "src/utils/grad_ckpt.py",
    "content": "import torch\nimport warnings\n\n\ndef detach_variable(inputs):\n    if isinstance(inputs, tuple):\n        out = []\n        for inp in inputs:\n            x = inp.detach()\n            x.requires_grad = inp.requires_grad\n            out.append(x)\n        return tuple(out)\n    else:\n        raise RuntimeError(\n            \"Only tuple of tensors is supported. Got Unsupported input type: \", type(inputs).__name__)\n\n\ndef check_backward_validity(inputs):\n    if not any(inp.requires_grad for inp in inputs):\n        warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        for i in range(len(ctx.input_tensors)):\n            temp = ctx.input_tensors[i]\n            ctx.input_tensors[i] = temp.detach()\n            ctx.input_tensors[i].requires_grad = temp.requires_grad\n        with torch.enable_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)\n        return (None, None) + input_grads"
  },
  {
    "path": "src/utils/load_save.py",
    "content": "\"\"\"\nsaving utilities\n\"\"\"\nimport json\nimport os\nfrom os.path import dirname, exists, join, realpath\nimport subprocess\nfrom apex import amp\nfrom easydict import EasyDict as edict\n\nimport torch\nfrom src.utils.basic_utils import save_json, make_zipfile, load_json\nfrom src.utils.logger import LOGGER\nfrom typing import Any, Dict, Union\n\nfrom src.modeling.timesformer.helpers import resize_spatial_embedding, resize_temporal_embedding\n\n\ndef save_training_meta(args):\n    # args is an EasyDict object, treat it the same as a normal dict\n    os.makedirs(join(args.output_dir, 'log'), exist_ok=True)\n    os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True)\n\n    # training args\n    save_args_path = join(args.output_dir, 'log', 'args.json')\n    save_json(args, save_args_path, save_pretty=True)\n\n    # model args\n    model_config = json.load(open(args.model_config))\n    save_model_config_path = join(args.output_dir, 'log', 'model_config.json')\n    save_json(model_config, save_model_config_path, save_pretty=True)\n\n    # save a copy of the codebase. !!!Do not store heavy file in your codebase when using it.\n    code_dir = dirname(dirname(dirname(os.path.realpath(__file__))))\n    code_zip_filename = os.path.join(args.output_dir, \"code.zip\")\n    LOGGER.info(f\"Saving code from {code_dir} to {code_zip_filename}...\")\n    make_zipfile(code_dir, code_zip_filename,\n                 enclosing_dir=\"code\",\n                 exclude_dirs_substring=\"results\",\n                 exclude_dirs=[\"__pycache__\", \"output\", \"data\", \"ext\"],\n                 exclude_extensions=[\".pyc\", \".ipynb\", \".swap\", \".pt\"])\n    LOGGER.info(f\"Saving code done.\")\n\n\nclass ModelSaver(object):\n    def __init__(self, output_dir):\n        self.output_dir = output_dir\n        self.max_save_load_trial = 10\n\n    def save(self, step, model, optimizer=None, prefix=\"model\"):\n        model_path = join(self.output_dir, f\"{prefix}_step_{step}.pt\")\n        state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v\n                      for k, v in model.state_dict().items()}\n        # with retrial, as azure blob fails occasionally.\n        save_trial = 0\n        while save_trial < self.max_save_load_trial:\n            try:\n                LOGGER.info(f\"ModelSaver save trial NO. 
{save_trial}\")\n                torch.save(state_dict, model_path)\n                if optimizer is not None:\n                    optimizer_state_dict = \\\n                        {k: v.cpu() if isinstance(v, torch.Tensor) else v\n                         for k, v in optimizer.state_dict().items()}\n                    dump = {'step': step, 'optimizer': optimizer_state_dict}\n                    torch.save(\n                        dump,\n                        f'{self.output_dir}/{prefix}_step_{step}_train_state.pt')\n                break\n            except Exception as e:\n                save_trial += 1\n\n\ndef load_state_dict_with_pos_embed_resizing(model, loaded_state_dict_or_path, \n                                                    num_patches, num_frames, \n                                                    spatial_embed_key='visual_encoder.model.pos_embed', \n                                                    temporal_embed_key='visual_encoder.model.time_embed',\n                                                    strict=False,\n                                                    remove_text_encoder_prefix=False\n                                                    ):\n    \"\"\"operated in-place, no need to return `model`,\n    \n    Used to load e2e model checkpoints.\n\n    remove_text_encoder_prefix: set to True, when finetune downstream models from pre-trained checkpoints.\n    \"\"\"\n\n    if isinstance(loaded_state_dict_or_path, str):\n        loaded_state_dict = torch.load(\n            loaded_state_dict_or_path, map_location=\"cpu\")\n        \n    else:\n        loaded_state_dict = loaded_state_dict_or_path\n\n    new_state_dict = loaded_state_dict.copy()\n\n    for key in loaded_state_dict:\n        if 'text_encoder.bert' in key and remove_text_encoder_prefix:\n            new_key = key.replace('text_encoder.bert','text_encoder')\n            new_state_dict[new_key] = new_state_dict.pop(key)\n\n    loaded_state_dict = new_state_dict\n\n    ## Resizing spatial embeddings in case they don't match\n    if num_patches + 1 != loaded_state_dict[spatial_embed_key].size(1):\n        loaded_state_dict[spatial_embed_key] = resize_spatial_embedding(loaded_state_dict, spatial_embed_key, num_patches)\n    else:\n        LOGGER.info('The length of spatial position embedding matches. No need to resize.')\n\n    ## Resizing time embeddings in case they don't match\n    if temporal_embed_key in loaded_state_dict and num_frames != loaded_state_dict[temporal_embed_key].size(1):\n        loaded_state_dict[temporal_embed_key] = resize_temporal_embedding(loaded_state_dict, temporal_embed_key, num_frames)\n    else:\n        LOGGER.info('No temporal encoding found. Or the length of temporal position embedding matches. 
No need to resize.')\n\n    model_keys = set([k for k in list(model.state_dict().keys())])\n    load_keys = set(loaded_state_dict.keys())\n\n    toload = {}\n    mismatched_shape_keys = []\n    for k in model_keys:\n        if k in load_keys:\n            if model.state_dict()[k].shape != loaded_state_dict[k].shape:\n                mismatched_shape_keys.append(k)\n            else:\n                toload[k] = loaded_state_dict[k]\n\n    LOGGER.info(\"You can ignore the keys with `num_batches_tracked` or from task heads\")\n    LOGGER.info(\"Keys in loaded but not in model:\")\n    diff_keys = load_keys.difference(model_keys)\n    LOGGER.info(f\"In total {len(diff_keys)}, {sorted(diff_keys)}\")\n    LOGGER.info(\"Keys in model but not in loaded:\")\n    diff_keys = model_keys.difference(load_keys)\n    LOGGER.info(f\"In total {len(diff_keys)}, {sorted(diff_keys)}\")\n    LOGGER.info(\"Keys in model and loaded, but shape mismatched:\")\n    LOGGER.info(f\"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}\")\n    model.load_state_dict(toload, strict=strict)\n\ndef compare_dict_difference(dict1, dict2, dict1_name=\"dict1\",\n                            dict2_name=\"dict2\",\n                            print_value_diff=True, verbose=False):\n    \"\"\"\n    Args:\n        dict1:\n        dict2:\n        dict1_name:\n        dict2_name:\n        print_value_diff: bool, output dict value difference within shared keys\n            for dict1 and dict2. In effect only when verbose == True\n        verbose:\n    \"\"\"\n    keys1 = set(dict1.keys())\n    keys2 = set(dict2.keys())\n    shared_keys = keys1.intersection(keys2)\n    keys1_unique = keys1.difference(shared_keys)\n    keys2_unique = keys2.difference(shared_keys)\n    key_diff_list = list(keys1_unique) + list(keys2_unique)\n\n    # value difference in the shared keys in dict1 and dict2\n    value_diff_dict = {}\n    for k in shared_keys:\n        if dict1[k] != dict2[k]:\n            value_diff_dict[k] = [(dict1_name, dict1[k]), (dict2_name, dict2[k])]\n\n    if verbose:\n        LOGGER.info(\"=\" * 30 + \"key difference\")\n        LOGGER.info(f\"keys in {dict1_name} but not in {dict2_name}: \"\n                    f\"total {len(keys1_unique)}, {sorted(keys1_unique)}\")\n        LOGGER.info(f\"keys in {dict2_name} but not in {dict1_name}: \"\n                    f\"total {len(keys2_unique)}, {sorted(keys2_unique)}\")\n\n    if verbose and print_value_diff:\n\n        LOGGER.info(\"=\" * 30 + \"value difference\")\n        LOGGER.info(f\"{json.dumps(value_diff_dict, indent=4)}\")\n\n    return value_diff_dict, key_diff_list\n\n\ndef _to_cuda(state):\n    \"\"\" usually load from cpu checkpoint but need to load to cuda \"\"\"\n    if isinstance(state, torch.Tensor):\n        ret = state.cuda()  # assume propoerly set py torch.cuda.set_device\n        if 'Half' in state.type():\n            ret = ret.float()  # apex O2 requires it\n        return ret\n    elif isinstance(state, list):\n        new_state = [_to_cuda(t) for t in state]\n    elif isinstance(state, tuple):\n        new_state = tuple(_to_cuda(t) for t in state)\n    elif isinstance(state, dict):\n        new_state = {n: _to_cuda(t) for n, t in state.items()}\n    else:\n        return state\n    return new_state\n\n\ndef _to_cpu(state):\n    \"\"\" store in cpu to avoid GPU0 device, fp16 to save space \"\"\"\n    if isinstance(state, torch.Tensor):\n        ret = state.cpu()\n        if 'Float' in state.type():\n            ret = ret.half()\n        return 
ret\n    elif isinstance(state, list):\n        new_state = [_to_cpu(t) for t in state]\n    elif isinstance(state, tuple):\n        new_state = tuple(_to_cpu(t) for t in state)\n    elif isinstance(state, dict):\n        new_state = {n: _to_cpu(t) for n, t in state.items()}\n    else:\n        return state\n    return new_state\n\n\nclass TrainingRestorer(object):\n    \"\"\"ckpt_dict: a dict contains all optimizers/models\"\"\"\n    def __init__(self, opts, **ckpt_dict):\n        if exists(opts.output_dir):\n            restore_opts = json.load(open(\n                f'{opts.output_dir}/log/args.json', 'r'))\n            assert opts == edict(restore_opts)\n        # keep 2 checkpoints in case of corrupted\n        self.save_path = f'{opts.output_dir}/restore.pt'\n        self.backup_path = f'{opts.output_dir}/restore_backup.pt'\n        self.ckpt_dict = ckpt_dict\n        self.save_steps = opts.save_steps\n        self.amp = opts.fp16\n        # since saving to or loading from azure blob fails sometimes\n        self.max_save_load_trial = 10\n        if exists(self.save_path) or exists(self.backup_path):\n            LOGGER.info('found previous checkpoint. try to resume...')\n            # with retrial, as azure blob fails occasionally.\n            restore_trial = 0\n            while restore_trial < self.max_save_load_trial:\n                LOGGER.info(f\"TrainingRestorer restore trial NO. {restore_trial}\")\n                try:\n                    self.restore()\n                    break\n                except Exception as e:\n                    restore_trial += 1\n        else:\n            self.global_step = 0\n\n    def step(self):\n        self.global_step += 1\n        if self.global_step % self.save_steps == 0:\n            # with retrial, as azure blob fails occasionally.\n            save_trial = 0\n            while save_trial < self.max_save_load_trial:\n                LOGGER.info(f\"TrainingRestorer save trial NO. 
{save_trial}\")\n                try:\n                    self.save()\n                    break\n                except Exception as e:\n                    save_trial += 1\n\n    def save(self):\n        checkpoint_to_save = {'global_step': self.global_step}\n        for k in self.ckpt_dict:\n            checkpoint_to_save[k] = _to_cpu(self.ckpt_dict[k].state_dict())\n        if self.amp:\n            checkpoint_to_save['amp_state_dict'] = amp.state_dict()\n        if exists(self.save_path):\n            os.rename(self.save_path, self.backup_path)\n        torch.save(checkpoint_to_save, self.save_path)\n\n    def restore(self):\n        try:\n            checkpoint = torch.load(self.save_path)\n        except Exception:\n            checkpoint = torch.load(self.backup_path)\n        self.global_step = checkpoint['global_step']\n        for k in self.ckpt_dict:\n            self.ckpt_dict[k].load_state_dict(_to_cuda(checkpoint[k]))\n        if self.amp:\n            amp.load_state_dict(checkpoint['amp_state_dict'])\n        LOGGER.info(f'resume training from step {self.global_step}')\n\n\nclass E2E_TrainingRestorer(object):\n    def __init__(self, opts, model, optimizer):\n        if exists(f\"{opts.output_dir}/log/args.json\"):\n            restore_opts = json.load(\n                open(f'{opts.output_dir}/log/args.json', 'r'))\n            with open(join(\n                    opts.output_dir, 'log',\n                    'restore_args.json'), 'w') as writer:\n                json.dump(vars(opts), writer, indent=4)\n            # assert opts == edict(restore_opts)\n        # keep 2 checkpoints in case of corrupted\n        self.save_path = f'{opts.output_dir}/restore.pt'\n        self.backup_path = f'{opts.output_dir}/restore_backup.pt'\n        self.model = model\n        self.optimizer = optimizer\n        self.save_steps = int(opts.save_steps_ratio * opts.num_train_steps)\n        self.amp = opts.fp16\n        # since saving to or loading from azure blob fails sometimes\n        self.max_save_load_trial = 10\n        if exists(self.save_path) or exists(self.backup_path):\n            LOGGER.info('found previous checkpoint. try to resume...')\n            # with retrial, as azure blob fails occasionally.\n            restore_trial = 0\n            while restore_trial < self.max_save_load_trial:\n                LOGGER.info(f\"TrainingRestorer restore trial NO. {restore_trial}\")\n                try:\n                    self.restore(opts)\n                    break\n                except Exception as e:\n                    restore_trial += 1\n        else:\n            self.global_step = 0\n\n    def step(self):\n        self.global_step += 1\n        if self.global_step % self.save_steps == 0:\n            # with retrial, as azure blob fails occasionally.\n            save_trial = 0\n            while save_trial < self.max_save_load_trial:\n                LOGGER.info(f\"TrainingRestorer save trial NO. 
{save_trial}\")\n                try:\n                    self.save()\n                    break\n                except Exception as e:\n                    save_trial += 1\n\n    def save(self):\n        checkpoint = {'global_step': self.global_step,\n                      'model_state_dict': _to_cpu(self.model.state_dict()),\n                      'optim_state_dict': _to_cpu(self.optimizer.state_dict())}\n        if self.amp:\n            checkpoint['amp_state_dict'] = amp.state_dict()\n        if exists(self.save_path):\n            os.rename(self.save_path, self.backup_path)\n        torch.save(checkpoint, self.save_path)\n\n    def restore(self, opts):\n        try:\n            checkpoint = torch.load(self.save_path)\n        except Exception:\n            checkpoint = torch.load(self.backup_path)\n        self.global_step = checkpoint['global_step']\n        self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict']))\n        self.optimizer.load_state_dict(\n            _to_cuda(checkpoint['optim_state_dict']))\n        if self.amp:\n            amp.load_state_dict(checkpoint['amp_state_dict'])\n        LOGGER.info(f'resume training from step {self.global_step}')\n"
  },
  {
    "path": "src/utils/logger.py",
    "content": "\"\"\"\nreferences: UNITER\n\"\"\"\n\nimport logging\nfrom tensorboardX import SummaryWriter\n\n\n_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'\n_DATE_FMT = '%m/%d/%Y %H:%M:%S'\nlogging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)\nLOGGER = logging.getLogger('__main__')  # this is the global logger\n\n\ndef add_log_to_file(log_path):\n    fh = logging.FileHandler(log_path)\n    formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)\n    fh.setFormatter(formatter)\n    LOGGER.addHandler(fh)\n\n\nclass TensorboardLogger(object):\n    def __init__(self):\n        self._logger = None\n        self._global_step = 0\n\n    def create(self, path):\n        self._logger = SummaryWriter(path)\n\n    def noop(self, *args, **kwargs):\n        return\n\n    def step(self):\n        self._global_step += 1\n\n    @property\n    def global_step(self):\n        return self._global_step\n\n    @global_step.setter\n    def global_step(self, step):\n        self._global_step = step\n\n    def log_scalar_dict(self, log_dict, prefix=''):\n        \"\"\" log a dictionary of scalar values\"\"\"\n        if self._logger is None:\n            return\n        if prefix:\n            prefix = f'{prefix}_'\n        for name, value in log_dict.items():\n            if isinstance(value, dict):\n                self.log_scalar_dict(value, self._global_step,\n                                     prefix=f'{prefix}{name}')\n            else:\n                self._logger.add_scalar(f'{prefix}{name}', value,\n                                        self._global_step)\n\n    def __getattr__(self, name):\n        if self._logger is None:\n            return self.noop\n        return self._logger.__getattribute__(name)\n\n\nTB_LOGGER = TensorboardLogger()\n\n\nclass RunningMeter(object):\n    \"\"\" running meteor of a scalar value\n        (useful for monitoring training loss)\n    \"\"\"\n    def __init__(self, name, val=None, smooth=0.99):\n        self._name = name\n        self._sm = smooth\n        self._val = val\n\n    def __call__(self, value):\n        self._val = (value if self._val is None\n                     else value*(1-self._sm) + self._val*self._sm)\n\n    def __str__(self):\n        return f'{self._name}: {self._val:.4f}'\n\n    @property\n    def val(self):\n        return self._val\n\n    @property\n    def name(self):\n        return self._name\n"
  },
  {
    "path": "src/utils/misc.py",
    "content": "\"\"\"\nmodified from UNITER\n\"\"\"\nimport json\nimport random\nimport sys\n\nimport torch\nimport numpy as np\n\n\nclass NoOp(object):\n    \"\"\" useful for distributed training No-Ops \"\"\"\n    def __getattr__(self, name):\n        return self.noop\n\n    def noop(self, *args, **kwargs):\n        return\n\n\ndef set_random_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef zero_none_grad(model):\n    for p in model.parameters():\n        if p.grad is None and p.requires_grad:\n            p.grad = p.data.new(p.size()).zero_()\n"
  }
]