Repository: salesforce/ALPRO
Branch: main
Commit: d21173f55a73
Files: 62
Total size: 462.2 KB
Directory structure:
gitextract_egeq4lt4/
├── .gitignore
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE
├── README.md
├── SECURITY.md
├── config_release/
│ ├── base_model.json
│ ├── didemo_ret.json
│ ├── msrvtt_qa.json
│ ├── msrvtt_ret.json
│ ├── msvd_qa.json
│ ├── pretrain_alpro.json
│ ├── pretrain_prompter.json
│ ├── timesformer_divst_8x32_224_k600.json
│ └── timesformer_divst_8x32_224_k600_gc.json
├── env/
│ ├── install_pkg.sh
│ └── requirements.txt
├── run_scripts/
│ ├── clear_cuda_cache.sh
│ ├── ft_didemo_ret.sh
│ ├── ft_msrvtt_qa.sh
│ ├── ft_msrvtt_ret.sh
│ ├── ft_msvd_qa.sh
│ ├── inf_didemo_ret.sh
│ ├── inf_msrvtt_qa.sh
│ ├── inf_msrvtt_ret.sh
│ ├── inf_msvd_qa.sh
│ ├── pt_alpro.sh
│ └── pt_prompter.sh
└── src/
├── __init__.py
├── configs/
│ └── config.py
├── datasets/
│ ├── data_utils.py
│ ├── dataloader.py
│ ├── dataset_base.py
│ ├── dataset_pretrain_sparse.py
│ ├── dataset_video_qa.py
│ ├── dataset_video_retrieval.py
│ └── randaugment.py
├── modeling/
│ ├── alpro_models.py
│ ├── timesformer/
│ │ ├── __init__.py
│ │ ├── conv2d_same.py
│ │ ├── features.py
│ │ ├── helpers.py
│ │ ├── linear.py
│ │ ├── operators.py
│ │ ├── vit.py
│ │ └── vit_utils.py
│ ├── transformers.py
│ └── xbert.py
├── optimization/
│ ├── adamw.py
│ ├── sched.py
│ └── utils.py
├── pretrain/
│ ├── run_pretrain_contrastive_only.py
│ └── run_pretrain_sparse.py
├── tasks/
│ ├── run_video_qa.py
│ └── run_video_retrieval.py
└── utils/
├── basic_utils.py
├── distributed.py
├── grad_ckpt.py
├── load_save.py
├── logger.py
└── misc.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.vscode
# script
tmp_all/script/
# Philly-related #
pt/
.ptconfig
# Project-related #
*/*results*/
*results*/
tmp*/
cache/*
*/cache*/
tmp*.py
*pickle
# compiled files #
*.pyc
# Packages #
############
# it's better to unpack these files and commit the raw source
# git has its own built in compression methods
*.7z
*.dmg
*.gz
*.iso
*.jar
*.rar
*.tar
*.zip
# Logs and databases #
######################
*.log
*.sql
*.sqlite
.ipynb_checkpoints/
*.swp
*.vscode/
*.idea/
# OS generated files #
######################
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# project-specific
img
txt
ext
data
output
src/configs_local
================================================
FILE: CODEOWNERS
================================================
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source
================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Salesforce Open Source Community Code of Conduct
## About the Code of Conduct
Equality is a core value at Salesforce. We believe a diverse and inclusive
community fosters innovation and creativity, and are committed to building a
culture where everyone feels included.
Salesforce open-source projects are committed to providing a friendly, safe, and
welcoming environment for all, regardless of gender identity and expression,
sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
race, age, religion, level of experience, education, socioeconomic status, or
other similar personal characteristics.
The goal of this code of conduct is to specify a baseline standard of behavior so
that people with different social values and communication styles can work
together effectively, productively, and respectfully in our open source community.
It also establishes a mechanism for reporting issues and resolving conflicts.
All questions and reports of abusive, harassing, or otherwise unacceptable behavior
in a Salesforce open-source project may be reported by contacting the Salesforce
Open Source Conduct Committee at ossconduct@salesforce.com.
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of gender
identity and expression, sexual orientation, disability, physical appearance,
body size, ethnicity, nationality, race, age, religion, level of experience, education,
socioeconomic status, or other similar personal characteristics.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy toward other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Personal attacks, insulting/derogatory comments, or trolling
* Public or private harassment
* Publishing, or threatening to publish, others' private information—such as
a physical or electronic address—without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
* Advocating for or encouraging any of the above behaviors
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned with this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project email
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the Salesforce Open Source Conduct Committee
at ossconduct@salesforce.com. All complaints will be reviewed and investigated
and will result in a response that is deemed necessary and appropriate to the
circumstances. The committee is obligated to maintain confidentiality with
regard to the reporter of an incident. Further details of specific enforcement
policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership and the Salesforce Open Source Conduct
Committee.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
It includes adaptations and additions from [Go Community Code of Conduct][golang-coc],
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
[golang-coc]: https://golang.org/conduct
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
================================================
FILE: CONTRIBUTING-ARCHIVED.md
================================================
# ARCHIVED
This project is `Archived` and is no longer actively maintained.
We are not accepting contributions or pull requests.
================================================
FILE: LICENSE
================================================
Copyright (c) 2021, Salesforce.com, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# ALPRO (CVPR '22)
## ALPRO is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS), a one-stop library for language-vision intelligence!
## Align and Prompt: Video-and-Language Pre-training with Entity Prompts [[Paper](https://arxiv.org/abs/2112.09583)]
[Dongxu Li](https://www.linkedin.com/in/dongxu-li-a8a035110/), [Junnan Li](https://sites.google.com/site/junnanlics), [Hongdong Li](http://users.cecs.anu.edu.au/~hongdong/), [Juan Carlos Niebles](http://www.niebles.net/), [Steven C.H. Hoi](https://sites.google.com/view/stevenhoi/home)
Official PyTorch code for ALPRO. This repository supports pre-training as well as finetuning on
- Text-Video Retrieval on MSRVTT and DiDeMo.
- Video Question Answering on MSRVTT and MSVD.
## Requirements
Our implementation is tested on Ubuntu 20.04.1 with NVIDIA A100 GPUs. Other platforms and hardware may work, but support is not guaranteed. To install the required packages:
```bash
cd env && bash install_pkg.sh
```
## Data Preparation
1. Download Annotations and Pre-trained Checkpoints
- [Text annotations](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/data.zip)
- [Checkpoints of pre-trained model and finetuned model](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/output.zip)
- [External resources](https://storage.googleapis.com/sfr-vision-language-research/ALPRO/ext.zip)
- unzip `data.zip`, `output.zip`, `ext.zip` under `ALPRO/` (an optional scripted version is sketched below).
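- Optionally, the download and extraction can be scripted; a minimal Python sketch using the URLs above (run from the repository root):
```python
import urllib.request
import zipfile

# Fetch and unpack the annotation, checkpoint and external-resource archives
# under ALPRO/ (minimal sketch; adjust if the hosting URLs change).
BASE = "https://storage.googleapis.com/sfr-vision-language-research/ALPRO/"
for name in ("data.zip", "output.zip", "ext.zip"):
    urllib.request.urlretrieve(BASE + name, name)
    with zipfile.ZipFile(name) as zf:
        zf.extractall(".")
```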
2. Download raw videos of downstream datasets.
- MSRVTT:
- download train_val_videos.zip and test_videos.zip from e.g. [here](https://www.mediafire.com/folder/h14iarbs62e7p/shared).
- check md5sum:
```bash
51f2394d279cf84f1642defd9a651e6f train_val_videos.zip
0af68454cec9d586e92805739f3911d0 test_videos.zip
```
- unzip all the videos into `data/msrvtt_ret/videos` (10k in total).
- create the following soft link:
```bash
ln -s data/msrvtt_ret/videos data/msrvtt_qa/videos
```
- MSVD:
- download from official release:
```bash
wget -nc https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar
```
- check md5sum:
```bash
9bdb20fcf14d59524a6febca9f6a8d89 YouTubeClips.tar
```
- extract all the videos to `data/msvd_qa/videos` (1,970 videos in total).
```bash
mkdir data/msvd_qa/videos/
tar xvf YouTubeClips.tar -C data/msvd_qa/videos --strip-components=1
```
- DiDeMo:
- Follow the [instructions](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md) and download the videos from the official release [here](https://drive.google.com/drive/u/1/folders/1_oyJ5rQiZboipbMl6tkhY8v0s9zDkvJc);
- unzip all the videos into `data/didemo_ret/videos`.
- Note that a few videos may be missing; see [here](https://github.com/LisaAnne/LocalizingMoments/blob/master/README.md#getting-the-videos) for how to obtain them. Since they account for only a small portion of the training set, it is safe to ignore them.
- Convert all the DiDeMo videos into `*.mp4` format, e.g. using [`ffmpeg`](https://askubuntu.com/questions/396883/how-to-simply-convert-video-files-i-e-mkv-to-mp4) (see the sketch below).
- We obtained 10,463 videos following these steps (with one video `77807177@N00_5753455690_1e04ccb364` missing).
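- For reference, a minimal Python sketch of the conversion step (assuming `ffmpeg` is on `PATH`; the exact encoding flags are up to you):
```python
import subprocess
from pathlib import Path

video_dir = Path("data/didemo_ret/videos")
# Re-encode every non-mp4 video to .mp4 next to the original (illustrative only).
for src in sorted(video_dir.iterdir()):
    if not src.is_file() or src.suffix.lower() == ".mp4":
        continue
    dst = src.with_suffix(".mp4")
    subprocess.run(["ffmpeg", "-y", "-i", str(src), str(dst)], check=True)
```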
3. The directory is expected to be in the structure below:
```bash
.
|-config_release # configuration files
|-data # text annotations and raw videos
|---didemo_ret
|-----txt
|-----videos
|---msrvtt_qa/...
|---msrvtt_ret/...
|---msvd_qa/...
|-env # scripts to install packages
|-ext # external resources, e.g. bert tokenizer
|-output # checkpoints for pre-trained/finetuned models
|---downstreams
|-----didemo_ret
|-------public
|---------ckpt # official finetuned checkpoints
|---------log # inference log
|---------results_test
|-----------step_best_1_mean
|-----msrvtt_qa/...
|-----msrvtt_ret/...
|-----msvd_qa/...
|-run_scripts # bash scripts to launch experiments
|-src # source code
```
## Inference with Official Checkpoints
```bash
cd run_scripts
bash inf_msrvtt_ret.sh
# {'text2video': {'r1': 33.9, 'r5': 60.7, 'r10': 73.2, 'medianR': 3.0, 'meanR': 27.404}}
bash inf_didemo_ret.sh
# {'text2video': {'r1': 35.9, 'r5': 67.5, 'r10': 78.8, 'medianR': 3.0, 'meanR': 19.125}}
bash inf_msrvtt_qa.sh
# {'ratios': {'what_ratio': [68.48, 49872], 'who_ratio': [27.99, 20385], 'how_ratio': [2.25, 1640], 'where_ratio': [0.34, 250], 'when_ratio': [0.93, 677]}, 'overall_acc': 42.12, 'what_acc': 36.05, 'who_acc': 52.24, 'how_acc': 85.67, 'where_acc': 42.8, 'when_acc': 78.88}
bash inf_msvd_qa.sh
# {'ratios': {'what_ratio': [61.93, 8150], 'who_ratio': [34.6, 4554], 'how_ratio': [2.81, 370], 'where_ratio': [0.21, 28], 'when_ratio': [0.44, 58]}, 'overall_acc': 45.91, 'what_acc': 37.02, 'who_acc': 58.59, 'how_acc': 81.62, 'where_acc': 46.43, 'when_acc': 72.41}
```
## Downstream Task Finetuning
- To finetune on downstream tasks with the pre-trained checkpoint `output/pretrain/alpro_pretrained_ckpt.pt`:
```bash
cd run_scripts
bash ft_msrvtt_ret.sh
bash ft_didemo_ret.sh
bash ft_msrvtt_qa.sh
bash ft_msvd_qa.sh
```
For example, with MSRVTT retrieval:
```bash
cd ALPRO/
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
CONFIG_PATH='config_release/msrvtt_ret.json'
# change -np to the number of GPUs to use.
horovodrun -np 8 python src/tasks/run_video_retrieval.py \
--config $CONFIG_PATH \
--output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S') # change to your local path to store finetuning ckpts and logs
```
- Run inference with locally-finetuned checkpoints.
```bash
cd ALPRO/
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
STEP='best'
CONFIG_PATH='config_release/msrvtt_ret.json'
OUTPUT_DIR='[INPUT_YOUR_OUTPUT_PATH_HERE]'
TXT_DB='data/msrvtt_ret/txt/test.jsonl'
IMG_DB='data/msrvtt_ret/videos'
horovodrun -np 8 python src/tasks/run_video_retrieval.py \
--do_inference 1 \
--inference_split test \
--inference_model_step $STEP \
--inference_txt_db $TXT_DB \
--inference_img_db $IMG_DB \
--inference_batch_size 64 \
--output_dir $OUTPUT_DIR \
--config $CONFIG_PATH
```
- `OUTPUT_DIR` is the path after the `--output_dir` option in the finetuning script.
- `$STEP` is a string, which tells the script to use the checkpoint `$OUTPUT_DIR/ckpt/model_step_$STEP.pt` for inference.
## Pretraining
1. Download [WebVid2M](https://github.com/m-bain/frozen-in-time) and [CC-3M](https://github.com/igorbrigadir/DownloadConceptualCaptions).
- Put WebVid2M videos under `data/webvid2m`;
- 💡 we downsample the WebVid2M videos to 10% of the original FPS to speed up video loading;
- update `data/cc3m/txt/cc3m.json` so that it points to your local image paths (see the sketch below).
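- For example, assuming each entry of `cc3m.json` keeps its image path under an `image` key (check the actual schema of your copy first), the paths could be rewritten along these lines:
```python
import json

# Illustrative sketch: point CC3M annotations at locally downloaded images.
# The "image" field name and the target layout below are assumptions, not the
# verified schema of data/cc3m/txt/cc3m.json.
with open("data/cc3m/txt/cc3m.json") as f:
    anns = json.load(f)
for ann in anns:
    ann["image"] = "data/cc3m/images/" + ann["image"].split("/")[-1]
with open("data/cc3m/txt/cc3m.json", "w") as f:
    json.dump(anns, f)
```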
2. Training Prompter:
```bash
cd run_scripts && bash pt_prompter.sh
```
3. Training video-language model:
```bash
cd run_scripts && bash pt_alpro.sh
```
If you would like to use custom prompter weights, please change `teacher_weights_path` in `config_release/pretrain_alpro.json`.
4. To finetune with pre-trained checkpoints, please change `e2e_weights_path` in the finetuning config files, e.g. `config_release/msrvtt_ret.json`.
## Citation
If you find ALPRO useful for your research, please consider citing:
```bibtex
@inproceedings{li2021align,
title={Align and Prompt: Video-and-Language Pre-training with Entity Prompts},
author={Li, Dongxu and Li, Junnan and Li, Hongdong and Niebles, Juan Carlos and Hoi, Steven C. H.},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
```
## Acknowledgement
We thank members at Salesforce Research for their helpful discussions.
The implementation of ALPRO relies on resources from [ClipBERT](https://github.com/jayleicn/ClipBERT),
[transformers](https://github.com/huggingface/transformers),
and [TimeSformer](https://github.com/facebookresearch/TimeSformer/tree/main/timesformer/models).
The code is implemented using [PyTorch](https://github.com/pytorch/pytorch),
with multi-GPU support from [Horovod](https://github.com/horovod/horovod) and [gradient-checkpoint](https://github.com/csrhddlam/pytorch-checkpoint). We thank the original authors for open-sourcing their work and encourage ALPRO users to cite these works when applicable.
================================================
FILE: SECURITY.md
================================================
## Security
Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as possible, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.
================================================
FILE: config_release/base_model.json
================================================
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522,
"fusion_layer": 6,
"encoder_width": 768,
"itc_token_type": "cls"
}
================================================
FILE: config_release/didemo_ret.json
================================================
{
"train_datasets": [
{
"name": "didemo",
"txt": "data/didemo_ret/txt/train.jsonl",
"img": "data/didemo_ret/videos"
}
],
"val_datasets": [
{
"name": "didemo_retrieval",
"txt": "data/didemo_ret/txt/val.jsonl",
"img": "data/didemo_ret/videos"
}
],
"max_txt_len": 50,
"crop_img_size": 224,
"resize_size": 256,
"img_pixel_mean": [0.48145466, 0.4578275, 0.40821073],
"img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
"img_input_format": "RGB",
"num_frm": 8,
"train_n_clips": 1,
"max_n_example_per_group": 1,
"model_config": "config_release/base_model.json",
"tokenizer_dir": "ext/bert-base-uncased/",
"visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
"e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
"bert_weights_path": null,
"train_batch_size": 12,
"val_batch_size": 12,
"gradient_accumulation_steps": 1,
"num_train_epochs": 10,
"min_valid_steps": 20,
"num_valid": 20,
"learning_rate": 4e-5,
"weight_decay": 1e-3,
"decay": "linear",
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"grad_norm": 20.0,
"seed":42,
"fp16": 0,
"num_workers": 4
}
================================================
FILE: config_release/msrvtt_qa.json
================================================
{
"train_datasets": [
{
"name": "msrvtt_qa",
"txt": {
"msrvtt_qa": "data/msrvtt_qa/txt/train.jsonl"
},
"img": "data/msrvtt_qa/videos"
}
],
"val_datasets": [
{
"name": "msrvtt_qa",
"txt": {
"msrvtt_qa": "data/msrvtt_qa/txt/val.jsonl"
},
"img": "data/msrvtt_qa/videos"
}
],
"ans2label_path": "data/msrvtt_qa/txt/train_ans2label.json",
"max_txt_len": 40,
"crop_img_size": 224,
"resize_size": 256,
"img_pixel_mean": [0.48145466, 0.4578275, 0.40821073],
"img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
"img_input_format": "RGB",
"train_n_clips": 1,
"model_config": "config_release/base_model.json",
"tokenizer_dir": "ext/bert-base-uncased/",
"visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600_gc.json",
"e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
"num_frm": 16,
"train_batch_size": 12,
"val_batch_size": 12,
"gradient_accumulation_steps": 2,
"num_train_epochs": 10,
"min_valid_steps": 50,
"num_valid": 50,
"learning_rate": 5e-5,
"weight_decay": 1e-3,
"decay": "linear",
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"grad_norm": 5.0,
"cnn_lr_decay": "linear",
"seed":42,
"fp16": 0,
"classifier": "mlp",
"cls_hidden_scale": 2,
"task": "msrvtt_qa",
"num_workers": 4
}
================================================
FILE: config_release/msrvtt_ret.json
================================================
{
"train_datasets": [
{
"name": "msrvtt",
"txt": "data/msrvtt_ret/txt/train.jsonl",
"img": "data/msrvtt_ret/videos"
}
],
"val_datasets": [
{
"name": "msrvtt_retrieval",
"txt": "data/msrvtt_ret/txt/val.jsonl",
"img": "data/msrvtt_ret/videos"
}
],
"max_txt_len": 40,
"crop_img_size": 224,
"resize_size": 256,
"img_pixel_mean": [0.48145466, 0.4578275, 0.40821073],
"img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
"img_input_format": "RGB",
"train_n_clips": 1,
"model_config": "config_release/base_model.json",
"tokenizer_dir": "ext/bert-base-uncased/",
"visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
"e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
"num_frm": 8,
"train_batch_size": 8,
"val_batch_size": 8,
"gradient_accumulation_steps": 1,
"num_train_epochs": 5,
"min_valid_steps": 100,
"num_valid": 20,
"learning_rate": 2.5e-5,
"weight_decay": 1e-3,
"decay": "linear",
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"grad_norm": 5.0,
"seed":42,
"fp16": 0,
"num_workers": 4
}
================================================
FILE: config_release/msvd_qa.json
================================================
{
"train_datasets": [
{
"name": "msvd_qa",
"txt": {
"msvd_qa": "data/msvd_qa/txt/train.jsonl"
},
"img": "data/msvd_qa/videos"
}
],
"val_datasets": [
{
"name": "msvd_qa",
"txt": {
"msvd_qa": "data/msvd_qa/txt/val.jsonl"
},
"img": "data/msvd_qa/videos"
}
],
"ans2label_path": "data/msvd_qa/txt/train_ans2label.json",
"num_labels": 2423,
"max_txt_len": 40,
"crop_img_size": 224,
"resize_size": 256,
"img_pixel_mean": [0.48145466, 0.4578275, 0.40821073],
"img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
"img_input_format": "RGB",
"train_n_clips": 1,
"num_frm": 16,
"model_config": "config_release/base_model.json",
"tokenizer_dir": "ext/bert-base-uncased/",
"visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600_gc.json",
"e2e_weights_path": "output/pretrain/alpro_pretrained_ckpt.pt",
"train_batch_size": 12,
"val_batch_size": 12,
"gradient_accumulation_steps": 2,
"num_train_epochs": 15,
"min_valid_steps": 50,
"num_valid": 30,
"learning_rate": 5e-5,
"weight_decay": 1e-3,
"decay": "linear",
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"grad_norm": 20.0,
"cnn_lr_decay": "linear",
"seed":42,
"fp16": 0,
"save_steps_ratio": 0.05,
"classifier": "mlp",
"cls_hidden_scale": 2,
"task": "msvd_qa",
"num_workers": 4
}
================================================
FILE: config_release/pretrain_alpro.json
================================================
{
"train_datasets": [
{
"name": "webvid2m",
"ann": "data/webvid2m/txt/train.pkl",
"txt": null,
"img": "data/webvid2m/videos"
},
{
"name": "cc3m",
"ann": "data/cc3m/txt/cc3m.json",
"txt": null,
"img": null
}
],
"val_datasets": [
{
"name": "webvid2m",
"ann": "data/webvid2m/txt/val.pkl",
"txt": null,
"img": "data/webvid2m/videos"
}
],
"img_pixel_mean": [0.48145466, 0.4578275, 0.40821073],
"img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
"img_input_format": "RGB",
"model_type": "pretrain",
"model_config": "config_release/base_model.json",
"visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
"visual_weights_path": "vit_base_patch16_224",
"teacher_weights_path": "output/pretrain/prompter_pretrained.pt",
"entity_file_path": "data/unigrams.txt",
"tokenizer_dir": "ext/bert-base-uncased/",
"max_txt_len": 30,
"crop_img_size": 224,
"resize_size": 256,
"train_batch_size": 16,
"val_batch_size": 16,
"gradient_accumulation_steps": 1,
"num_train_epochs": 10,
"min_valid_steps": 10,
"num_valid": 10,
"learning_rate": 1e-4,
"decay": "linear",
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"weight_decay": 1e-3,
"grad_norm": 20.0,
"seed":42,
"fp16": 0,
"use_itm": 1,
"use_mlm": 1,
"use_itc": 1,
"use_mpm": 1,
"n_workers": 4,
"save_steps_ratio": 0.01,
"frm_sampling_strategy": "headtail",
"num_frm": 4,
"fps": 0.5,
"debug": false,
"warmup_ratio": 0.05,
"log_interval": 100
}
================================================
FILE: config_release/pretrain_prompter.json
================================================
{
"train_datasets": [
{
"name": "webvid2m",
"ann": "data/webvid2m/txt/train.pkl",
"txt": null,
"img": "data/webvid2m/videos"
},
{
"name": "cc3m",
"ann": "data/cc3m/txt/cc3m.json",
"txt": null,
"img": null
}
],
"val_datasets": [
{
"name": "webvid2m",
"ann": "data/webvid2m/txt/val.pkl",
"txt": null,
"img": "data/webvid2m/videos"
}
],
"img_pixel_mean": [0.48145466, 0.4578275, 0.40821073],
"img_pixel_std": [0.26862954, 0.26130258, 0.27577711],
"img_input_format": "RGB",
"model_type": "pretrain",
"model_config": "config_release/base_model.json",
"visual_model_cfg": "config_release/timesformer_divst_8x32_224_k600.json",
"visual_weights_path": "vit_base_patch16_224",
"tokenizer_dir": "ext/bert-base-uncased/",
"max_txt_len": 30,
"crop_img_size": 224,
"resize_size": 256,
"train_batch_size": 16,
"val_batch_size": 16,
"gradient_accumulation_steps": 2,
"num_train_epochs": 10,
"min_valid_steps": 100,
"num_valid": 10,
"learning_rate": 1e-4,
"decay": "linear",
"optim": "adamw",
"betas": [0.9, 0.98],
"dropout": 0.1,
"weight_decay": 1e-3,
"grad_norm": 20.0,
"seed":42,
"fp16": 0,
"use_itm": 0,
"use_mlm": 0,
"use_itc": 1,
"n_workers": 4,
"save_steps_ratio": 0.05,
"frm_sampling_strategy": "headtail",
"num_frm": 4,
"debug": false,
"warmup_ratio": 0.05,
"log_interval": 100
}
================================================
FILE: config_release/timesformer_divst_8x32_224_k600.json
================================================
{
"cls": "TimeSformer",
"patch_size": 16,
"attn_drop_rate": 0,
"drop_rate": 0,
"drop_path_rate": 0.1,
"maxpool_kernel_size": 2,
"use_maxpooling": false,
"gradient_checkpointing": false
}
================================================
FILE: config_release/timesformer_divst_8x32_224_k600_gc.json
================================================
{
"cls": "TimeSformer",
"patch_size": 16,
"attn_drop_rate": 0,
"drop_rate": 0,
"drop_path_rate": 0.1,
"maxpool_kernel_size": 2,
"use_maxpooling": false,
"gradient_checkpointing": true
}
================================================
FILE: env/install_pkg.sh
================================================
apt update
apt install lsof
# horovod
HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_NCCL_LINK=SHARED HOROVOD_WITH_PYTORCH=1 \
pip install --no-cache-dir horovod==0.19.4 &&\
ldconfig
# use the faster pillow-simd instead of the original pillow
# https://github.com/uploadcare/pillow-simd
pip uninstall pillow && \
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
spacy download en
pip install -r requirements.txt
git clone https://github.com/NVIDIA/apex.git &&\
cd apex &&\
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . &&\
rm -rf ../apex
================================================
FILE: env/requirements.txt
================================================
ipdb
joblib
cytoolz
lz4==2.1.9
lmdb==0.97
msgpack-numpy
msgpack
toolz
transformers==4.11.3
tensorboard
tqdm
easydict
pycocotools>=2.0.1
opencv-python
tensorboardX==2.0
av==8.0.2
ujson
einops
decord
timm
================================================
FILE: run_scripts/clear_cuda_cache.sh
================================================
for i in $(lsof /dev/nvidia* | grep python | awk '{print $2}' | sort -u); do kill -9 $i; done
================================================
FILE: run_scripts/ft_didemo_ret.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
CONFIG_PATH='config_release/didemo_ret.json'
horovodrun -np 8 python src/tasks/run_video_retrieval.py \
--config $CONFIG_PATH \
--output_dir /export/home/workspace/experiments/alpro/finetune/didemo_ret/$(date '+%Y%m%d%H%M%S')
================================================
FILE: run_scripts/ft_msrvtt_qa.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
CONFIG_PATH='config_release/msrvtt_qa.json'
horovodrun -np 8 python src/tasks/run_video_qa.py \
--config $CONFIG_PATH \
--output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_qa/$(date '+%Y%m%d%H%M%S')
================================================
FILE: run_scripts/ft_msrvtt_ret.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
CONFIG_PATH='config_release/msrvtt_ret.json'
horovodrun -np 8 python src/tasks/run_video_retrieval.py \
--config $CONFIG_PATH \
--output_dir /export/home/workspace/experiments/alpro/finetune/msrvtt_ret/$(date '+%Y%m%d%H%M%S')
================================================
FILE: run_scripts/ft_msvd_qa.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
CONFIG_PATH='config_release/msvd_qa.json'
horovodrun -np 8 python src/tasks/run_video_qa.py \
--config $CONFIG_PATH \
--output_dir /export/home/workspace/experiments/alpro/finetune/msvd_qa/$(date '+%Y%m%d%H%M%S')
================================================
FILE: run_scripts/inf_didemo_ret.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
STEP='best'
CONFIG_PATH='config_release/didemo_ret.json'
TXT_DB='data/didemo_ret/txt/test.jsonl'
IMG_DB='data/didemo_ret/videos'
horovodrun -np 8 python src/tasks/run_video_retrieval.py \
--do_inference 1 \
--inference_split test \
--inference_model_step $STEP \
--inference_txt_db $TXT_DB \
--inference_img_db $IMG_DB \
--inference_batch_size 64 \
--output_dir output/downstreams/didemo_ret/public \
--config $CONFIG_PATH
================================================
FILE: run_scripts/inf_msrvtt_qa.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
STEP='best'
CONFIG_PATH='config_release/msrvtt_qa.json'
TXT_DB='data/msrvtt_qa/txt/test.jsonl'
IMG_DB='data/msrvtt_qa/videos'
horovodrun -np 8 python src/tasks/run_video_qa.py \
--do_inference 1 \
--inference_split test \
--inference_model_step $STEP \
--inference_txt_db $TXT_DB \
--inference_img_db $IMG_DB \
--inference_batch_size 64 \
--output_dir output/downstreams/msrvtt_qa/public \
--config $CONFIG_PATH
================================================
FILE: run_scripts/inf_msrvtt_ret.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
STEP='best'
CONFIG_PATH='config_release/msrvtt_ret.json'
TXT_DB='data/msrvtt_ret/txt/test.jsonl'
IMG_DB='data/msrvtt_ret/videos'
horovodrun -np 8 python src/tasks/run_video_retrieval.py \
--do_inference 1 \
--inference_split test \
--inference_model_step $STEP \
--inference_txt_db $TXT_DB \
--inference_img_db $IMG_DB \
--inference_batch_size 64 \
--output_dir output/downstreams/msrvtt_ret/public \
--config $CONFIG_PATH
================================================
FILE: run_scripts/inf_msvd_qa.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
STEP='best'
CONFIG_PATH='config_release/msvd_qa.json'
TXT_DB='data/msvd_qa/txt/test.jsonl'
IMG_DB='data/msvd_qa/videos'
horovodrun -np 8 python src/tasks/run_video_qa.py \
--do_inference 1 \
--inference_split test \
--inference_model_step $STEP \
--inference_txt_db $TXT_DB \
--inference_img_db $IMG_DB \
--inference_batch_size 64 \
--output_dir output/downstreams/msvd_qa/public \
--config $CONFIG_PATH
================================================
FILE: run_scripts/pt_alpro.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
CONFIG_PATH='config_release/pretrain_alpro.json'
horovodrun -np 16 python src/pretrain/run_pretrain_sparse.py \
--config $CONFIG_PATH \
--output_dir /export/home/workspace/experiments/alpro/vl/$(date '+%Y%m%d%H%M%S')
================================================
FILE: run_scripts/pt_prompter.sh
================================================
cd ..
export PYTHONPATH="$PYTHONPATH:$PWD"
echo $PYTHONPATH
CONFIG_PATH='config_release/pretrain_prompter.json'
horovodrun -np 8 python src/pretrain/run_pretrain_contrastive_only.py \
--config $CONFIG_PATH \
--output_dir /export/home/workspace/experiments/alpro/prompter/$(date '+%Y%m%d%H%M%S')
================================================
FILE: src/__init__.py
================================================
================================================
FILE: src/configs/config.py
================================================
"""
Modified from UNITER code
"""
import os
import sys
import json
import argparse
import warnings
from easydict import EasyDict as edict
def parse_with_config(parsed_args):
"""This function will set args based on the input config file.
(1) it only overwrites unset parameters,
i.e., those parameters not set from the user's command-line input;
(2) it also sets configs that appear in the config file but are not declared in the parser.
"""
# convert to EasyDict object, enabling access from attributes even for nested config
# e.g., args.train_datasets[0].name
args = edict(vars(parsed_args))
if args.config is not None:
config_args = json.load(open(args.config))
override_keys = {arg[2:].split("=")[0] for arg in sys.argv[1:]
if arg.startswith("--")}
for k, v in config_args.items():
if k not in override_keys:
setattr(args, k, v)
del args.config
return args
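# Illustrative usage sketch (the parser below is hypothetical, not from this repo):
# command-line values take precedence, unset options fall back to the JSON config.
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--config")
#   parser.add_argument("--learning_rate", type=float, default=5e-5)
#   args = parse_with_config(parser.parse_args())
#   # args.learning_rate keeps the CLI value if --learning_rate was passed;
#   # otherwise it is overwritten by the value from the JSON config (if present).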
class SharedConfigs(object):
"""Shared options for pre-training and downstream tasks.
For each downstream task, implement a get_*_args function,
see `get_pretraining_args()`
Usage:
>>> shared_configs = SharedConfigs()
>>> pretraining_config = shared_configs.get_pretraining_args()
"""
def __init__(self, desc="shared config for pretraining and finetuning"):
parser = argparse.ArgumentParser(description=desc)
# debug parameters
parser.add_argument(
"--debug", type=int, choices=[0, 1], default=0,
help="debug mode, output extra info & break all loops."
"0: disable, 1 enable")
parser.add_argument(
"--data_ratio", type=float, default=1.0,
help="portion of train/val exampels to use,"
"e.g., overfit a small set of data")
# Required parameters
parser.add_argument(
"--model_config", type=str,
help="path to model structure config json")
parser.add_argument(
"--tokenizer_dir", type=str, help="path to tokenizer dir")
parser.add_argument(
"--output_dir", type=str,
help="dir to store model checkpoints & training meta.")
# data preprocessing parameters
parser.add_argument(
"--max_txt_len", type=int, default=20, help="max text #tokens ")
# parser.add_argument(
# "--max_img_size", type=int, default=448,
# help="max image longer side size, shorter side will be padded with zeros")
parser.add_argument(
"--img_pixel_mean", type=float, default=None,
nargs=3, help="image pixel mean")
parser.add_argument(
"--img_pixel_std", type=float, default=None,
nargs=3, help="image pixel std")
parser.add_argument(
"--img_input_format", type=str, default="BGR",
choices=["BGR", "RGB"], help="image input format is BGR for detectron2")
parser.add_argument(
"--max_n_example_per_group", type=int, default=1,
help="max #examples (e.g., captions) paired with each image/video in an input group."
"1: each image is paired with a single sent., equivalent to sample by sent.;"
"X (X>1): each image can be paired with a maximum of X sent.; X>1 can be used "
"to reduce image processing time, including basic transform (resize, etc) and CNN encoding"
)
# video specific parameters
parser.add_argument("--fps", type=int, default=1, help="video frame rate to use")
parser.add_argument("--num_frm", type=int, default=3,
help="#frames to use per clip -- we first sample a clip from a video, "
"then uniformly sample num_frm from the clip. The length of the clip "
"will be fps * num_frm")
parser.add_argument("--frm_sampling_strategy", type=str, default="rand",
choices=["rand", "uniform", "start", "middle", "end"],
help="see src.datasets.dataset_base.extract_frames_from_video_binary for details")
# MIL training settings
parser.add_argument("--train_n_clips", type=int, default=3,
help="#clips to sample from each video for MIL training")
parser.add_argument("--score_agg_func", type=str, default="mean",
choices=["mean", "max", "lse"],
help="score (from multiple clips) aggregation function, lse = LogSumExp")
parser.add_argument("--random_sample_clips", type=int, default=1, choices=[0, 1],
help="randomly sample clips for training, otherwise use uniformly sampled clips.")
# training parameters
parser.add_argument(
"--train_batch_size", default=128, type=int,
help="Single-GPU batch size for training for Horovod.")
parser.add_argument(
"--val_batch_size", default=128, type=int,
help="Single-GPU batch size for validation for Horovod.")
parser.add_argument(
"--gradient_accumulation_steps", type=int, default=1,
help="#updates steps to accumulate before performing a backward/update pass."
"Used to simulate larger batch size training. The simulated batch size "
"is train_batch_size * gradient_accumulation_steps for a single GPU.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="initial learning rate.")
parser.add_argument(
"--log_interval", default=500, type=int,
help="record every a few steps on tensorboard.")
parser.add_argument(
"--num_valid", default=20, type=int,
help="Run validation X times during training and checkpoint.")
parser.add_argument(
"--min_valid_steps", default=100, type=int,
help="minimum #steps between two validation runs")
parser.add_argument(
"--save_steps_ratio", default=0.01, type=float,
help="save every 0.01*global steps to resume after preemption,"
"not used for checkpointing.")
parser.add_argument("--num_train_epochs", default=10, type=int,
help="Total #training epochs.")
parser.add_argument("--optim", default="adamw",
choices=["adam", "adamax", "adamw"],
help="optimizer")
parser.add_argument("--betas", default=[0.9, 0.98],
nargs=2, help="beta for adam optimizer")
parser.add_argument("--decay", default="linear",
choices=["linear", "invsqrt"],
help="learning rate decay method")
parser.add_argument("--dropout", default=0.1, type=float,
help="tune dropout regularization")
parser.add_argument("--weight_decay", default=1e-3, type=float,
help="weight decay (L2) regularization")
parser.add_argument("--grad_norm", default=2.0, type=float,
help="gradient clipping (-1 for no clipping)")
parser.add_argument(
"--warmup_ratio", default=0.1, type=float,
help="to perform linear learning rate warmup for. (invsqrt decay)")
parser.add_argument("--transformer_lr_mul", default=1.0, type=float,
help="lr_mul for transformer")
parser.add_argument("--step_decay_epochs", type=int,
nargs="+", help="multi_step decay epochs")
# model arch
parser.add_argument(
"--model_type", type=str, default="pretrain",
help="type of e2e model to use. Support only 'pretrain' for now. ")
parser.add_argument(
"--timesformer_model_cfg", type=str, default="",
help="path to timesformer model cfg yaml")
# checkpoint
parser.add_argument("--e2e_weights_path", type=str,
help="path to e2e model weights")
parser.add_argument(
"--clip_init", default=0, type=int, choices=[0, 1],
help="1 for using clip ckpt for init.")
parser.add_argument("--bert_weights_path", type=str,
help="path to BERT weights, only use for pretraining")
# inference only, please include substring `inference'
# in the option name to avoid being overwritten by loaded options,
# see start_inference() in run_vqa_w_hvd.py
parser.add_argument("--inference_model_step", default=-1, type=str,
help="pretrained model checkpoint step")
parser.add_argument(
"--do_inference", default=0, type=int, choices=[0, 1],
help="perform inference run. 0: disable, 1 enable")
parser.add_argument(
"--inference_split", default="val",
help="For val, the data should have ground-truth associated it."
"For test*, the data comes with no ground-truth.")
parser.add_argument("--inference_txt_db", type=str,
help="path to txt_db file for inference")
parser.add_argument("--inference_img_db", type=str,
help="path to img_db file for inference")
parser.add_argument("--inference_batch_size", type=int, default=64,
help="single-GPU batch size for inference")
parser.add_argument("--inference_n_clips", type=int, default=1,
help="uniformly sample `ensemble_n_clips` clips, "
"each contains `num_frm` frames. When it == 1, "
"use the frm_sampling_strategy to sample num_frm frames."
"When it > 1, ignore frm_sampling_strategy, "
"uniformly sample N clips, each clips num_frm frames.")
# device parameters
parser.add_argument("--seed", type=int, default=42,
help="random seed for initialization")
parser.add_argument(
"--fp16", type=int, choices=[0, 1], default=0,
help="Use 16-bit float precision instead of 32-bit."
"0: disable, 1 enable")
parser.add_argument("--n_workers", type=int, default=4,
help="#workers for data loading")
parser.add_argument("--pin_mem", type=int, choices=[0, 1], default=1,
help="pin memory. 0: disable, 1 enable")
# can use config files, will only overwrite unset parameters
parser.add_argument("--config", help="JSON config files")
self.parser = parser
def parse_args(self):
parsed_args = self.parser.parse_args()
args = parse_with_config(parsed_args)
# convert all [0, 1] options to bool, including the task-specific ones
zero_one_options = [
"fp16", "pin_mem", "use_itm", "use_mlm", "use_itc", "debug", #"freeze_cnn",
"do_inference",
]
for option in zero_one_options:
if hasattr(args, option):
setattr(args, option, bool(getattr(args, option)))
# basic checks
# This is handled at TrainingRestorer
# if exists(args.output_dir) and os.listdir(args.output_dir):
# raise ValueError(f"Output directory ({args.output_dir}) "
# f"already exists and is not empty.")
if args.step_decay_epochs and args.decay != "multi_step":
warnings.warn(
f"--step_decay_epochs set to {args.step_decay_epochs}, "
f"but will not be effective, as --decay is set to {args.decay}.")
assert args.gradient_accumulation_steps >= 1, \
f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps} "
assert 1 >= args.data_ratio > 0, \
f"--data_ratio should be in (0, 1.0], but got {args.data_ratio}"
return args
def get_sparse_pretraining_args(self):
# pre-training args
self.parser.add_argument(
"--use_itm", type=int, choices=[0, 1], default=0,
help="enable itm loss. 0: disable, 1 enable")
self.parser.add_argument(
"--use_mlm", type=int, choices=[0, 1], default=0,
help="enable mlm loss. 0: disable, 1 enable")
self.parser.add_argument(
"--use_itc", type=int, choices=[0, 1], default=0,
help="enable itc loss. 0: disable, 1 enable")
# sparse pretraining-specific settings
self.parser.add_argument(
"--crop_img_size", type=int, default=256,
help="crop size during pre-training.")
self.parser.add_argument(
"--resize_size", type=int, default=288,
help="resize frames to square, ignoring aspect ratio.")
# MPM-specific
self.parser.add_argument(
"--use_mpm", type=int, choices=[0, 1], default=0,
help="enable mpm loss. 0: disable, 1 enable")
self.parser.add_argument("--teacher_weights_path", type=str,
help="path to teacher model weights, only use for pretraining.")
self.parser.add_argument("--entity_file_path", type=str,
help="path to selected NOUN entities.")
self.parser.add_argument(
"--num_entities", type=int, default=1000,
help="maximum entities to consider for pseudo labels.")
args = self.parse_args()
return args
def get_video_retrieval_args(self):
self.parser.add_argument("--eval_retrieval_batch_size", type=int, default=256,
help="batch size for retrieval, since each batch will only have one image, "
"retrieval allows larger batch size")
args = self.parse_args()
return args
def get_nlvl_args(self):
args = self.parse_args()
return args
def get_vqa_args(self):
self.parser.add_argument("--ans2label_path", type=str,
help="path to {answer: label} file")
self.parser.add_argument("--loss_type", type=str, default="bce",
help="loss type")
self.parser.add_argument("--classifier", type=str, default="mlp",
choices=["mlp", "linear"],
help="classifier type")
self.parser.add_argument(
"--cls_hidden_scale", type=int, default=2,
help="scaler of the intermediate linear layer dimension for mlp classifier")
self.parser.add_argument("--num_labels", type=int, default=3129,
help="#labels/output-dim for classifier")
return self.parse_args()
def get_video_qa_args(self):
self.parser.add_argument(
"--task", type=str,
choices=["action", "transition", "frameqa", "msrvtt_qa"],
help="TGIF-QA tasks and MSRVTT-QA")
self.parser.add_argument("--loss_type", type=str, default="ce",
help="loss type, will be overwritten later")
self.parser.add_argument("--classifier", type=str, default="mlp",
choices=["mlp", "linear"],
help="classifier type")
self.parser.add_argument(
"--cls_hidden_scale", type=int, default=2,
help="scaler of the intermediate linear layer dimension for mlp classifier")
# for frameQA msrvtt_qa
self.parser.add_argument("--ans2label_path", type=str, default=None,
help="path to {answer: label} file")
# manually setup config by task type
args = self.parse_args()
if args.max_n_example_per_group != 1:
Warning(f"For TGIF-QA, most GIF is only paired with a single example, no need to"
f"use max_n_example_per_group={args.max_n_example_per_group}"
f"larger than 1. Automatically reset to 1.")
args.max_n_example_per_group = 1
if args.ans2label_path is not None and os.path.exists(args.ans2label_path):
num_answers = len(json.load(open(args.ans2label_path, "r")))
else:
num_answers = 0
if args.task in ["msrvtt_qa", "msvd_qa"]:
args.num_labels = max(num_answers, 1500)
args.loss_type = "ce"
else:
raise NotImplementedError
return args
shared_configs = SharedConfigs()
================================================
FILE: src/datasets/data_utils.py
================================================
import torch
import random
import torchvision.transforms as transforms
from torchvision.transforms.functional import pad as img_pad
from torchvision.transforms.functional import resize as img_resize
from torch.nn.functional import interpolate as img_tensor_resize
from torch.nn.functional import pad as img_tensor_pad
from torch.nn.modules.utils import _quadruple
from src.utils.basic_utils import flat_list_of_lists
import numbers
import numpy as np
from PIL import Image
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
Image.HAMMING: 'PIL.Image.HAMMING',
Image.BOX: 'PIL.Image.BOX',
}
def mask_batch_text_tokens(
inputs, tokenizer, mlm_probability=0.15, is_train=True):
""" modified from transformers.data.data_collator
Args:
inputs: (B, L), 2D torch.Tensor, does not work for 1D. It has already been padded.
tokenizer:
mlm_probability: float
is_train: if True use random masking, else mask tokens at fixed position to remove randomness in evaluation.
"""
if tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training
# (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, mlm_probability)
special_tokens_mask = [
tokenizer.get_special_tokens_mask(
val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(
special_tokens_mask, dtype=torch.bool), value=0.0)
if tokenizer._pad_token is not None:
padding_mask = labels.eq(tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(
torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(
tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(
torch.full(labels.shape, 0.5)
).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(
len(tokenizer), labels.shape,
dtype=torch.long) # len(tokenizer) == #vocab
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
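# Illustrative usage sketch (hypothetical inputs); assumes a BERT tokenizer
# loaded from the local ext/bert-base-uncased/ directory:
#
#   from transformers import BertTokenizer
#   tokenizer = BertTokenizer.from_pretrained("ext/bert-base-uncased/")
#   batch = tokenizer(["a man is cooking", "a dog runs"],
#                     padding=True, return_tensors="pt")
#   input_ids, mlm_labels = mask_batch_text_tokens(
#       batch.input_ids.clone(), tokenizer, mlm_probability=0.15)
#   # ~15% of non-special tokens are selected; of those, 80% become [MASK],
#   # 10% are replaced by a random token and 10% are kept unchanged.
#   # Unselected positions get label -100 and are ignored by the MLM loss.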
def select_batch_text_pivots(
inputs, tokenizer, ent2id, mpm_probability=1.0, is_train=True):
""" Given a input text sequence, generate masks and prototype labels such that:
1) not to mask special token ([CLS], [SEP], [MASK], [PAD]);
2) always mask all BPE in a word together.
Args:
"""
if tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
# We sample a few tokens in each sequence as pivots
probability_matrix = torch.full(labels.shape, mpm_probability)
# ignore [CLS] [SEP] [MASK] tokens
special_tokens_mask = [
tokenizer.get_special_tokens_mask(
val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
# ignore [PAD] tokens
if tokenizer._pad_token is not None:
padding_mask = labels.eq(tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
# create masking indices
pivot_indices = torch.bernoulli(probability_matrix).bool()
labels[special_tokens_mask] = -100 # We only compute loss on masked tokens
labels[~pivot_indices] = -100 # We only compute loss on masked tokens
# selected pivot positions: (1) non-special token; (2) selected based on mpm probability.
text_pivots_pos = (labels > 0).nonzero()
for tpp in text_pivots_pos:
orig_tpp = tpp.clone()
bth = tpp[0]
orig_text_pos = orig_tpp[1]
text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]
next_text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]+1]])[0] if tpp[1]+1 < inputs.shape[1] else None
# TODO may consider support encoding beyond sentencepiece.
if text_token.startswith('##'):
# if it is a byte pair, backtrace until we find the prefix
orig_text_token = ''
while True:
if not text_token.startswith('##'):
orig_text_token = text_token + orig_text_token
break
else:
orig_text_token = text_token[2:] + orig_text_token
tpp[1] -= 1
text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]
try:
# assign prototype labels to all the sentencepiece bytes in the pivot word
labels[bth][tpp[1]: orig_text_pos + 1] = ent2id[orig_text_token]
pivot_indices[bth][tpp[1]: orig_text_pos + 1] = True
except KeyError:
# we do not have this word for prototype
labels[bth][orig_text_pos] = -100
elif next_text_token is not None and next_text_token.startswith('##'):
# if it is a prefix, forward-trace until we find the end of the byte pair
full_text_token = text_token
while True:
tpp[1] += 1
text_token = tokenizer.convert_ids_to_tokens([inputs[bth][tpp[1]]])[0]
if not text_token.startswith('##'):
# find the next prefix/word
break
else:
# find continuing bytes
full_text_token = full_text_token + text_token[2:]
try:
# assign prototype labels to all the sentencepiece bytes in the pivot word
labels[bth][orig_text_pos: tpp[1]] = ent2id[full_text_token]
pivot_indices[bth][orig_text_pos: tpp[1]] = True
except KeyError:
# we do not have this word for prototype
labels[bth][orig_text_pos] = -100
else:
# the word is treated in whole by BERT tokenizer
try:
labels[bth][tpp[1]] = ent2id[text_token]
except KeyError:
# we do not have this word for prototype
labels[bth][tpp[1]] = -100
# restore mask if the word is not in the entity list
pivot_indices[labels==-100] = False
return pivot_indices, labels
def image_to_tensor(image: np.ndarray, keepdim: bool = True) -> torch.Tensor:
"""Converts a numpy image to a PyTorch 4d tensor image.
Args:
image (numpy.ndarray): image of the form :math:`(H, W, C)`, :math:`(H, W)` or
:math:`(B, H, W, C)`.
keepdim (bool): If ``False`` unsqueeze the input image to match the shape
:math:`(B, H, W, C)`. Default: ``True``
Returns:
torch.Tensor: tensor of the form :math:`(B, C, H, W)` if keepdim is ``False``,
:math:`(C, H, W)` otherwise.
"""
if not isinstance(image, (np.ndarray,)):
raise TypeError("Input type must be a numpy.ndarray. Got {}".format(
type(image)))
if len(image.shape) > 4 or len(image.shape) < 2:
raise ValueError(
"Input size must be a two, three or four dimensional array")
input_shape = image.shape
tensor: torch.Tensor = torch.from_numpy(image)
if len(input_shape) == 2:
# (H, W) -> (1, H, W)
tensor = tensor.unsqueeze(0)
elif len(input_shape) == 3:
# (H, W, C) -> (C, H, W)
tensor = tensor.permute(2, 0, 1)
elif len(input_shape) == 4:
# (B, H, W, C) -> (B, C, H, W)
tensor = tensor.permute(0, 3, 1, 2)
keepdim = True # no need to unsqueeze
else:
raise ValueError(
"Cannot process image with shape {}".format(input_shape))
return tensor.unsqueeze(0) if not keepdim else tensor
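# Illustrative usage sketch (hypothetical input):
#   img = np.zeros((480, 640, 3), dtype=np.uint8)     # (H, W, C)
#   image_to_tensor(img).shape                        # torch.Size([3, 480, 640])
#   image_to_tensor(img, keepdim=False).shape         # torch.Size([1, 3, 480, 640])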
def get_padding(image, max_w, max_h, pad_all=False):
# keep the images to upper-left corner
if isinstance(image, torch.Tensor):
h, w = image.shape[-2:]
else:
w, h = image.size
h_padding, v_padding = max_w - w, max_h - h
if pad_all:
h_padding /= 2
v_padding /= 2
l_pad = h_padding if h_padding % 1 == 0 else h_padding+0.5
t_pad = v_padding if v_padding % 1 == 0 else v_padding+0.5
r_pad = h_padding if h_padding % 1 == 0 else h_padding-0.5
b_pad = v_padding if v_padding % 1 == 0 else v_padding-0.5
else:
l_pad, t_pad = 0, 0
r_pad, b_pad = h_padding, v_padding
if isinstance(image, torch.Tensor):
padding = (int(l_pad), int(r_pad), int(t_pad), int(b_pad))
else:
padding = (int(l_pad), int(t_pad), int(r_pad), int(b_pad))
return padding
class ImagePad(object):
def __init__(self, max_w, max_h, fill=0, padding_mode='constant'):
assert isinstance(fill, (numbers.Number, str, tuple))
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
self.max_w = max_w
self.max_h = max_h
self.fill = fill
self.padding_mode = padding_mode
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be padded.
Returns:
PIL Image: Padded image.
"""
if isinstance(img, torch.Tensor):
paddings = _quadruple(get_padding(img, self.max_w, self.max_h))
return img_tensor_pad(
img, paddings,
self.padding_mode, self.fill)
return img_pad(
img, get_padding(img, self.max_w, self.max_h),
self.fill, self.padding_mode)
def __repr__(self):
return self.__class__.__name__ + '(max_w={0}, max_h={1}, fill={2}, padding_mode={3})'.\
format(self.max_w, self.max_h, self.fill, self.padding_mode)
def get_resize_size(image, max_size):
"""
Args:
image: PIL Image or torch.tensor
max_size:
Returns:
Note the height/width order difference
>>> pil_img = Image.open("raw_img_tensor.jpg")
>>> pil_img.size
(640, 480) # (width, height)
>>> np_img = np.array(pil_img)
>>> np_img.shape
(480, 640, 3) # (height, width, 3)
"""
# note the order of height and width for different inputs
if isinstance(image, torch.Tensor):
# width, height = image.shape[-2:]
height, width = image.shape[-2:]
else:
width, height = image.size
if height >= width:
ratio = width*1./height
new_height = max_size
new_width = new_height * ratio
else:
ratio = height*1./width
new_width = max_size
new_height = new_width * ratio
size = (int(new_height), int(new_width))
return size
class VideoRandomSquareCrop(object):
def __init__(self, crop_size, p=0.5):
assert isinstance(crop_size, int)
self.crop_size = crop_size
self.p = p
def __call__(self, video):
"""
Args:
video (torch.Tensor): video tensor of shape (b, t, h, w) to be cropped.
Returns:
torch.Tensor: cropped video.
"""
if isinstance(video, torch.Tensor):
if len(video.shape) == 4:
b, t, h, w = video.shape
else:
raise RuntimeError('Expecting 4-dimensional tensor of shape (b,t,h,w), got {}'.format(video.shape))
# if random.uniform(0, 1) < self.p:
# video = torch.flip(video, (3,))
x = random.randint(0, h - self.crop_size)
y = random.randint(0, w - self.crop_size)
return video[:, :, x: x + self.crop_size, y: y + self.crop_size]
else:
raise NotImplementedError('Support only torch.Tensor as input, got {}'.format(type(video)))
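# A minimal usage sketch of VideoRandomSquareCrop, assuming a (b, t, h, w) tensor;
# the shape and crop size below are illustrative only.
def _example_video_random_square_crop():
    cropper = VideoRandomSquareCrop(crop_size=224)
    video = torch.zeros(1, 8, 256, 256)   # (b, t, h, w)
    cropped = cropper(video)              # same clip, random 224x224 spatial window
    assert cropped.shape == (1, 8, 224, 224)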
class VideoResizeSquare(object):
def __init__(self, out_size, interpolation='nearest'):
assert isinstance(out_size, int)
self.out_size = out_size
self.interpolation = interpolation
def __call__(self, video):
"""
Args:
video (torch.Tensor): video tensor of shape (t, c, h, w) to be scaled.
Returns:
torch.Tensor: rescaled video.
"""
if isinstance(video, torch.Tensor):
if len(video.shape) == 4:
t, c, h, w = video.shape
assert c == 3, 'Expecting 3-channel color video, got video of shape {}'.format(video.shape)
else:
raise RuntimeError('Expecting 4-dimensional tensor of shape (t,c,h,w), got {}'.format(video.shape))
short_side = h if h < w else w
# scaling_factor = self.out_size / short_side
# new_h = int(h * scaling_factor)
# new_w = int(w * scaling_factor)
resized_video = img_tensor_resize(video, size=((self.out_size, self.out_size)), mode=self.interpolation)
return resized_video
else:
# in other data class, the order of shape might be different.
raise NotImplementedError('Support only torch.Tensor as input, got {}'.format(type(video)))
def __repr__(self):
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(
self.out_size, self.interpolation)
class ImageResize(object):
"""Resize the input image (torch.tensor) to the given size.
Args:
max_size (int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
the longer edge of the image will be matched to this number.
i.e., if height > width, then image will be rescaled to
(size, size * width / height)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, max_size, interpolation=Image.BILINEAR):
assert isinstance(max_size, int)
self.max_size = max_size
self.interpolation = interpolation
def __call__(self, img):
"""
Args:
img (torch.tensor): Image to be scaled.
Returns:
torch.tensor: Rescaled image.
"""
if isinstance(img, torch.Tensor):
assert isinstance(self.interpolation, str)
return img_tensor_resize(
img, size=get_resize_size(img, self.max_size),
mode=self.interpolation, align_corners=False)
return img_resize(
img, get_resize_size(img, self.max_size), self.interpolation)
def __repr__(self):
interpolate_str = _pil_interpolation_to_str[self.interpolation]
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(
self.max_size, interpolate_str)
def get_imagenet_transform(min_size=600, max_size=1000):
"""parameters from https://github.com/pytorch/examples/blob/master/imagenet/main.py
This simply crop the center square from the image
"""
if min_size != 600:
import warnings
warnings.warn(f'Warning: min_size is not used in image transform, '
f'setting min_size will have no effect.')
return transforms.Compose([
ImageResize(max_size, Image.BILINEAR), # longer side will be resized to 1000
ImagePad(max_size, max_size), # pad to 1000 * 1000
])
class ImageNorm(object):
"""Apply Normalization to Image Pixels on GPU
"""
def __init__(self, mean, std):
self.mean = torch.tensor(mean).cuda().view(1, 1, 3, 1, 1)
self.std = torch.tensor(std).cuda().view(1, 1, 3, 1, 1)
# assert max(std) <= 1 and min(std) >= 0\
# or max(mean) <= 1 and min(mean) >= 0,\
# "Please provide mean or std within range [0, 1]"
def __call__(self, img):
"""
Args:
img: float image tensors, (B, N, 3, H, W)
Returns:
img: normalized float image tensors
"""
if torch.max(img) > 1 and self.mean.max() <= 1:
img.div_(255.)
return img.sub_(self.mean).div_(self.std)
def chunk_list(examples, chunk_size=2, pad_to_divisible=True):
"""
Args:
examples: iterable, examples grouped by image/video
chunk_size: int, number of examples in each chunk.
pad_to_divisible: bool, pad the examples to be divisible by chunk_size.
>>> test_examples = [3, 4, 5, 6, 7]
>>> chunk_list(test_examples, chunk_size=2, pad_to_divisible=True)
[[3, 4], [5, 6], [7, 7]] # the last element is padded by random sampling, so it may vary
>>> chunk_list(test_examples, chunk_size=2, pad_to_divisible=False)
[[3, 4], [5, 6], [7]]
"""
n_examples = len(examples)
remainder = n_examples % chunk_size
if pad_to_divisible and remainder > 0:
n_pad = chunk_size - remainder
pad = random.choices(examples, k=n_pad) # with replacement
examples = examples + pad
n_examples = len(examples)
remainder = 0
chunked_examples = []
n_chunks = int(n_examples / chunk_size)
n_chunks = n_chunks + 1 if remainder > 0 else n_chunks
for i in range(n_chunks):
chunked_examples.append(examples[i*chunk_size: (i+1)*chunk_size])
return chunked_examples
def mk_input_group(key_grouped_examples, max_n_example_per_group=1, is_train=True,
example_unique_key=None):
""" Re-organize examples into groups. Each input group will have a single image paired
with X (X=max_n_example_per_img) examples. Images with total #examples > X will be
split into multiple groups. In case a group has < X examples, we copy
examples to make the group have X examples.
Args:
key_grouped_examples: dict, each key is image/video id,
each value is a list(example) associated with this image/video
max_n_example_per_group: int, max number of examples to pair with each image/video.
Note that each image can have multiple groups.
is_train: bool, if True, copy the examples to make sure each input
group has max_n_example_per_group examples.
example_unique_key: str, used to make sure no inputs are discarded by matching
the input and output ids specified by `example_unique_key`
"""
input_groups = [] # each element is (id, list(example))
for k, examples in key_grouped_examples.items():
chunked_examples = chunk_list(examples,
chunk_size=max_n_example_per_group,
pad_to_divisible=is_train)
for c in chunked_examples:
# if len(c) == 0:
# continue
input_groups.append((k, c))
if example_unique_key is not None:
print(f"Using example_unique_key {example_unique_key} to check whether input and output ids m")
# sanity check: make sure we did not discard any input example by accident.
input_question_ids = flat_list_of_lists(
[[sub_e[example_unique_key] for sub_e in e] for e in key_grouped_examples.values()])
output_question_ids = flat_list_of_lists(
[[sub_e[example_unique_key] for sub_e in e[1]] for e in input_groups])
assert set(input_question_ids) == set(output_question_ids), "You are missing examples: input and output ids do not match."
return input_groups
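# A minimal usage sketch of mk_input_group, assuming one video with three examples
# grouped two at a time; the toy dict below is illustrative only.
def _example_mk_input_group():
    key_grouped = {"vid0": [dict(qid=0), dict(qid=1), dict(qid=2)]}
    groups = mk_input_group(key_grouped, max_n_example_per_group=2, is_train=True)
    # -> [("vid0", [ex0, ex1]), ("vid0", [ex2, <randomly repeated example>])]
    assert len(groups) == 2 and all(len(chunk) == 2 for _, chunk in groups)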
# def repeat_tensor_rows(raw_tensor, row_repeats):
# """ repeat raw_tensor[i] row_repeats[i] times.
# Args:
# raw_tensor: (B, *)
# row_repeats: list(int), len(row_repeats) == len(raw_tensor)
# """
# assert len(raw_tensor) == len(raw_tensor), "Has to be the same length"
# if sum(row_repeats) == len(row_repeats):
# return raw_tensor
# else:
# indices = torch.LongTensor(
# flat_list_of_lists([[i] * r for i, r in enumerate(row_repeats)])
# ).to(raw_tensor.device)
# return raw_tensor.index_select(0, indices)
================================================
FILE: src/datasets/dataloader.py
================================================
"""
modified from UNITER codebase
A meta data loader for sampling from different datasets / training tasks
A prefetch loader to speedup data loading
"""
import random
import torch
from torch.utils.data import DataLoader
from src.utils.distributed import any_broadcast
class MetaLoader(object):
""" wraps multiple data loader """
def __init__(self, loaders, accum_steps=1, distributed=False):
assert isinstance(loaders, dict)
self.name2loader = {}
self.name2iter = {}
self.sampling_pools = []
n_batches_in_epoch = 0
for n, l in loaders.items():
if isinstance(l, tuple):
l, r = l
elif isinstance(l, DataLoader):
r = 1
else:
raise ValueError()
n_batches_in_epoch += len(l.dataset) * r / l.batch_size
self.name2loader[n] = l
self.name2iter[n] = iter(l)
self.sampling_pools.extend([n]*r)
self.n_batches_in_epoch = n_batches_in_epoch
self.accum_steps = accum_steps
self.distributed = distributed
self.step = 0
def __iter__(self):
""" this iterator will run indefinitely """
task = self.sampling_pools[0]
while True:
if self.step % self.accum_steps == 0:
task = random.choice(self.sampling_pools)
if self.distributed:
# make sure all processes are training on the same task
task = any_broadcast(task, 0)
self.step += 1
iter_ = self.name2iter[task]
try:
batch = next(iter_)
except StopIteration:
iter_ = iter(self.name2loader[task])
batch = next(iter_)
self.name2iter[task] = iter_
yield task, batch
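# A minimal usage sketch of MetaLoader, assuming two existing DataLoaders; the task
# names and ratios below are illustrative only (a higher ratio means the task is
# sampled more often).
def _example_meta_loader(video_loader, image_loader):
    meta_loader = MetaLoader(
        {"video_pretrain": (video_loader, 2),   # drawn roughly twice as often
         "image_pretrain": (image_loader, 1)},
        accum_steps=1, distributed=False)
    for step, (task, batch) in enumerate(meta_loader):
        print(task)          # name of the loader that produced `batch`
        if step >= 10:       # the iterator runs indefinitely, so stop manually
            break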
def move_to_cuda(batch):
if isinstance(batch, torch.Tensor):
return batch.cuda(non_blocking=True)
elif isinstance(batch, list):
new_batch = [move_to_cuda(t) for t in batch]
elif isinstance(batch, tuple):
new_batch = tuple(move_to_cuda(t) for t in batch)
elif isinstance(batch, dict):
new_batch = {n: move_to_cuda(t) for n, t in batch.items()}
else:
return batch
return new_batch
def record_cuda_stream(batch):
if isinstance(batch, torch.Tensor):
batch.record_stream(torch.cuda.current_stream())
elif isinstance(batch, list) or isinstance(batch, tuple):
for t in batch:
record_cuda_stream(t)
elif isinstance(batch, dict):
for t in batch.values():
record_cuda_stream(t)
else:
pass
class PrefetchLoader(object):
"""
overlap compute and cuda data transfer
(copied and then modified from nvidia apex)
"""
def __init__(self, loader, img_normalize=None):
self.loader = loader
self.stream = torch.cuda.Stream()
self.img_normalize = img_normalize
def __iter__(self):
loader_it = iter(self.loader)
self.preload(loader_it)
batch = self.next(loader_it)
while batch is not None:
is_tuple = isinstance(batch, tuple)
if is_tuple:
task, batch = batch
batch["visual_inputs"] = batch["visual_inputs"].float()
if self.img_normalize is not None:
batch["visual_inputs"] = self.img_normalize(
batch["visual_inputs"])
if "crop_visual_inputs" in batch:
batch["crop_visual_inputs"] = batch["crop_visual_inputs"].float()
batch["crop_visual_inputs"] = self.img_normalize(
batch["crop_visual_inputs"])
if "context_visual_inputs" in batch:
batch["context_visual_inputs"] = batch["context_visual_inputs"].float()
batch["context_visual_inputs"] = self.img_normalize(
batch["context_visual_inputs"])
if is_tuple:
yield task, batch
else:
yield batch
batch = self.next(loader_it)
def __len__(self):
return len(self.loader)
def preload(self, it):
try:
self.batch = next(it)
except StopIteration:
self.batch = None
return
# if record_stream() doesn't work, another option is to make sure
# device inputs are created on the main stream.
# self.next_input_gpu = torch.empty_like(self.next_input,
# device='cuda')
# self.next_target_gpu = torch.empty_like(self.next_target,
# device='cuda')
# Need to make sure the memory allocated for next_* is not still in use
# by the main stream at the time we start copying to next_*:
# self.stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.stream):
self.batch = move_to_cuda(self.batch)
# more code for the alternative if record_stream() doesn't work:
# copy_ will record the use of the pinned source tensor in this
# side stream.
# self.next_input_gpu.copy_(self.next_input, non_blocking=True)
# self.next_target_gpu.copy_(self.next_target, non_blocking=True)
# self.next_input = self.next_input_gpu
# self.next_target = self.next_target_gpu
def next(self, it):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
if batch is not None:
record_cuda_stream(batch)
self.preload(it)
return batch
def __getattr__(self, name):
method = self.loader.__getattribute__(name)
return method
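# A minimal usage sketch of PrefetchLoader, assuming a CUDA device is available and
# that `train_loader` yields dicts containing a "visual_inputs" tensor; the
# normalization statistics below are illustrative only.
def _example_prefetch_loader(train_loader):
    from src.datasets.data_utils import ImageNorm
    img_norm = ImageNorm(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    prefetch = PrefetchLoader(train_loader, img_normalize=img_norm)
    for batch in prefetch:
        # batches arrive on GPU with "visual_inputs" already normalized
        break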
class InfiniteIterator(object):
"""iterate an iterable oobject infinitely"""
def __init__(self, iterable):
self.iterable = iterable
self.iterator = iter(iterable)
def __iter__(self):
while True:
try:
batch = next(self.iterator)
except StopIteration:
self.iterator = iter(self.iterable)
batch = next(self.iterator)
yield batch
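# A minimal usage sketch of InfiniteIterator: a finite iterable is restarted
# transparently once exhausted; the toy list below is illustrative only.
def _example_infinite_iterator():
    inf = InfiniteIterator([1, 2, 3])
    first_five = [x for x, _ in zip(iter(inf), range(5))]
    assert first_five == [1, 2, 3, 1, 2]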
================================================
FILE: src/datasets/dataset_base.py
================================================
from torch.utils.data import Dataset
from PIL import Image
import io
import av
import torch
import numpy as np
import lmdb
import random
import decord
from decord import VideoReader
from src.datasets.data_utils import (
ImageResize, ImagePad, image_to_tensor)
from src.utils.load_save import LOGGER
decord.bridge.set_bridge("torch")
class AlproBaseDataset(Dataset):
"""
datalist: list(dicts) # lightly pre-processed
{
"type": "image",
"filepath": "/abs/path/to/COCO_val2014_000000401092.jpg",
"text": "A plate of food and a beverage are on a table.",
# should be tokenized and digitized first?
...
}
tokenizer:
max_img_size: int,
max_txt_len: int, max text sequence length, including special tokens.
fps: float, frame per second
num_frm: #frames to use as input.
"""
def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type='lmdb', fps=3, num_frm=3,
frm_sampling_strategy="rand", max_img_size=-1, max_txt_len=20):
self.fps = fps
self.num_frm = num_frm
self.frm_sampling_strategy = frm_sampling_strategy
self.datalist = datalist
self.tokenizer = tokenizer
self.max_txt_len = max_txt_len
self.max_img_size = max_img_size
self.img_resize = ImageResize(
max_img_size,
"bilinear") # longer side will be resized to 1000
self.img_pad = ImagePad(
max_img_size, max_img_size) # pad to 1000 * 1000
self.img_db_type = img_db_type
assert img_db_type in ['lmdb', 'rawvideo'], "Invalid type for img_db_type, expected {'lmdb', 'rawvideo'}, found {}.".format(img_db_type)
if self.img_db_type == 'lmdb':
self.env = lmdb.open(
img_lmdb_dir, readonly=True,
create=False) # readahead=not _check_distributed()
self.txn = self.env.begin(buffers=True)
else:
self.img_db_dir = img_lmdb_dir
def __len__(self):
return len(self.datalist)
def __getitem__(self, index):
raise NotImplementedError
def _load_img(self, img_id):
"""Load and apply transformation to image
Returns:
torch.float, in [0, 255], (n_frm=1, c, h, w)
"""
raw_img = load_decompress_img_from_lmdb_value(
self.txn.get(str(img_id).encode("utf-8"))
)
image_np = np.array(raw_img, dtype=np.uint8) # (h, w, c)
raw_img_tensor = image_to_tensor(
image_np, keepdim=False).float() # (c, h, w) [0, 255]
resized_img = self.img_resize(raw_img_tensor)
transformed_img = self.img_pad(
resized_img) # (n_frm=1, c, h, w)
return transformed_img
@classmethod
def _is_extreme_aspect_ratio(cls, tensor, max_ratio=5.):
""" find extreme aspect ratio, where longer side / shorter side > max_ratio
Args:
tensor: (*, H, W)
max_ratio: float, max ratio (>1).
"""
h, w = tensor.shape[-2:]
return h / float(w) > max_ratio or h / float(w) < 1 / max_ratio
def _load_video(self, video_id, num_clips=None, clip_idx=None,
safeguard_duration=False, video_max_pts=None):
"""Load and sample frames from video.
Apply transformation to the sampled frames.
Sample a clip:
- random: set num_clips and clip_idx to be None
- uniform: set num_clips=N, clip_idx=idx. e.g., num_clips=3
and clip_idx=1 will first segment the video into 3 clips,
then sample the 2nd clip.
Returns:
torch.float, in [0, 255], (n_frm=T, c, h, w)
"""
assert (num_clips is None) == (clip_idx is None), "Both None, or both not None"
# (T, C, H, W) [0, 255]
io_stream = io.BytesIO(self.txn.get(str(video_id).encode("utf-8")))
raw_sampled_frms, video_max_pts = extract_frames_from_video_binary(
io_stream,
target_fps=self.fps,
num_frames=self.num_frm,
multi_thread_decode=False,
sampling_strategy=self.frm_sampling_strategy,
num_clips=num_clips,
clip_idx=clip_idx,
safeguard_duration=safeguard_duration,
video_max_pts=video_max_pts
)
if raw_sampled_frms is None:
return None, None
elif self._is_extreme_aspect_ratio(raw_sampled_frms, max_ratio=5.):
print(
f"Found extreme aspect ratio for video id {video_id}. Skip it")
return None, None
raw_sampled_frms = raw_sampled_frms.float()
resized_frms = self.img_resize(raw_sampled_frms)
padded_frms = self.img_pad(resized_frms)
return padded_frms, video_max_pts
def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
try:
if not height or not width:
vr = VideoReader(video_path)
else:
vr = VideoReader(video_path, width=width, height=height)
vlen = len(vr)
if start_time or end_time:
assert fps > 0, 'must provide video fps if specifying start and end time.'
start_idx = min(int(start_time * fps), vlen)
end_idx = min(int(end_time * fps), vlen)
else:
start_idx, end_idx = 0, vlen
if self.frm_sampling_strategy == 'uniform':
frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
elif self.frm_sampling_strategy == 'nlvl_uniform':
frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm).astype(int)
elif self.frm_sampling_strategy == 'nlvl_rand':
frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm).astype(int)
# generate some random perturbations
strides = [frame_indices[i] - frame_indices[i-1] for i in range(1, len(frame_indices))] + [vlen - frame_indices[-1]]
pertube = np.array([np.random.randint(0, stride) for stride in strides])
frame_indices = frame_indices + pertube
elif self.frm_sampling_strategy == 'rand':
frame_indices = sorted(random.sample(range(vlen), self.num_frm))
elif self.frm_sampling_strategy == 'headtail':
frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
frame_indices = frame_indices_head + frame_indices_tail
else:
raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
raw_sample_frms = vr.get_batch(frame_indices)
except Exception as e:
return None
raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
return raw_sample_frms
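# A minimal stand-alone sketch of the uniform-style sampling above (mirroring the
# 'nlvl_uniform' branch): evenly spaced frame indices over the clip; the vlen and
# num_frm values below are illustrative only.
def _example_uniform_frame_indices(vlen=120, num_frm=8):
    # 8 indices spaced vlen / num_frm = 15 frames apart: [0, 15, 30, ..., 105]
    return np.arange(0, vlen, vlen / num_frm).astype(int)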
def img_collate(imgs):
"""
Args:
imgs: list of PIL.Image, all of the same size.
Returns:
torch.Tensor, uint8, (B, 3, H, W)
"""
w = imgs[0].width
h = imgs[0].height
tensor = torch.zeros(
(len(imgs), 3, h, w), dtype=torch.uint8).contiguous()
for i, img in enumerate(imgs):
nump_array = np.array(img, dtype=np.uint8)
if (nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
# (H, W, 3) --> (3, H, W)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor
================================================
FILE: src/datasets/dataset_pretrain_sparse.py
================================================
import os
import json
import random
import torch
import spacy
from torch.utils.data.dataloader import default_collate
from src.utils.logger import LOGGER
from src.utils.basic_utils import flat_list_of_lists, save_frames_grid
from src.datasets.data_utils import VideoRandomSquareCrop, VideoResizeSquare, mask_batch_text_tokens, select_batch_text_pivots
from src.datasets.dataset_base import AlproBaseDataset, img_collate
from src.datasets.randaugment import TemporalConsistentRandomAugment, RandomAugment
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
class AlproPretrainSparseDataset(AlproBaseDataset):
"""
datalist: list(tuples) each tuple is (img_id, list(dicts)),
each dict {
"type": "image",
"filepath": "/abs/path/to/COCO_val2014_000000401092.jpg",
"text": "A plate of food and a beverage are on a table.", # should be tokenized and digitized first?
...
}
tokenizer:
max_img_size: int,
max_txt_len: int, max text sequence length, including special tokens.
vis_format: str, image or video, used to decide data loading method.
"""
def __init__(self, datalist, tokenizer, img_lmdb_dir, img_db_type, txt_dir,
video_fmt='.mp4', crop_size=256, resize_size=288, fps=3, num_frm=3, frm_sampling_strategy="rand",
max_img_size=1000, max_txt_len=20,
use_itm=True, is_train=True):
super(AlproPretrainSparseDataset, self).__init__(
datalist, tokenizer, img_lmdb_dir,
img_db_type=img_db_type,
fps=fps,
num_frm=num_frm,
frm_sampling_strategy=frm_sampling_strategy,
max_img_size=max_img_size,
max_txt_len=max_txt_len)
self.use_itm = use_itm
self.txt_dir = txt_dir
self.video_fmt = video_fmt
self.crop_size = crop_size
self.video_random_cropper = VideoRandomSquareCrop(crop_size)
self.resize_size = resize_size
self.is_train = is_train
if self.is_train:
self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])
else:
self.randaug = None
def __len__(self):
return len(self.datalist)
def __getitem__(self, index):
start_time = None
end_time = None
# fetch video
num_retries = 10 # skip error videos
for _ in range(num_retries):
data_sample = self.datalist.iloc[index]
video_id = str(data_sample.video_id)
txt_len = int(data_sample.txt_len)
if hasattr(data_sample, 'text'):
text = data_sample.text.strip()
else:
raise NotImplementedError("Un-supported text annotation format.")
# fetch video
video_path = os.path.join(self.img_db_dir, video_id + self.video_fmt)
# read with retries
for i in range(3):
img_array = self._load_video_from_path_decord(video_path, height=self.resize_size, width=self.resize_size)
if img_array is not None:
break
if img_array is not None:
t, c, h, w = img_array.shape
# Select a random video if the current video could not be accessed.
if img_array is None:
LOGGER.info(f"Failed to load examples with video: {video_path}. "
f"Will randomly sample an example as a replacement.")
index = random.randint(0, len(self) - 1)
continue
else:
# square crop
img_array = self.video_random_cropper(img_array)
if self.randaug:
img_array = self.randaug(img_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
break
else:
raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
examples = [{'text_str': text, 'itm_label': 1}]
return dict(
img=img_array, # (T, C, H, W)
examples=examples,
n_examples=len(examples), # used to create image feature copies.
type='video'
)
class PretrainImageTextDataset(Dataset):
def __init__(self, datalist, tokenizer, is_train=True, crop_size=256, resize_size=288, num_frm=4, max_txt_len=40):
self.datalist = datalist
self.max_txt_len = max_txt_len
self.crop_size = crop_size
self.resize_size = resize_size
self.num_frms = num_frm
self.is_train = is_train
self.transform = transforms.Compose([
transforms.RandomResizedCrop(self.crop_size, scale=(0.2, 1.0), interpolation=Image.BICUBIC),
transforms.RandomHorizontalFlip(),
RandomAugment(2,7,isPIL=True,augs=['Identity','Brightness','Sharpness',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'])
])
def __len__(self):
return len(self.datalist)
def __getitem__(self, index):
start_time = None
end_time = None
# fetch video
num_retries = 10 # skip error videos
for _ in range(num_retries):
data_sample = self.datalist[index]
try:
if type(data_sample['caption']) == list:
text = random.choice(data_sample['caption'])
else:
text = data_sample['caption']
img_path = data_sample['image']
img_arr = Image.open(img_path).convert('RGB')
img_arr = self.transform(img_arr)
img_arr = np.asarray(img_arr, dtype=np.float32).transpose(2, 0, 1)
img_arr = torch.from_numpy(img_arr).unsqueeze(0)
img_arr = img_arr.repeat(self.num_frms, 1, 1, 1)
except Exception as e:
img_arr = None
if img_arr is not None:
t, c, h, w = img_arr.shape
# Select a random example if the current image could not be accessed.
if img_arr is None:
LOGGER.info(f"Failed to load examples with image: {img_path}. "
f"Will randomly sample an example as a replacement.")
index = random.randint(0, len(self) - 1)
continue
else:
break
else:
raise RuntimeError(f"Failed to fetch image after {num_retries} retries.")
examples = [{'text_str': text, 'itm_label': 1}]
return dict(
img=img_arr, # (T, C, H, W)
examples=examples,
n_examples=len(examples), # used to create image feature copies.
type='img'
)
class PretrainCollator(object):
"""is_train is kept here if we want to remove
the randomness during validation of MLM accuracy.
In that case, instantiate two PretrainCollator"""
def __init__(self, tokenizer,
mlm=True, mlm_probability=0.15,
patch_size=16,
mpm=True,
max_length=20, is_train=True):
self.tokenizer = tokenizer
self.mlm = mlm
self.mlm_probability = mlm_probability
self.max_length = max_length
self.is_train = is_train
self.mpm = mpm
self.patch_size = patch_size
def collate_batch(self, batch):
if isinstance(batch[0]["img"], torch.Tensor):
v_collate = default_collate
else:
v_collate = img_collate
visual_inputs = v_collate([d["img"] for d in batch]) # (B, #frm=1 or T, 3, H, W)
# group data
text_examples = flat_list_of_lists([d["examples"] for d in batch])
n_examples_list = [d["n_examples"] for d in batch] # (B, )
# group elements data
batch_enc = self.tokenizer.batch_encode_plus(
[d["text_str"] for d in text_examples],
max_length=self.max_length,
padding='max_length',
return_tensors="pt",
truncation=True
)
text_input_ids = batch_enc.input_ids # (B, L)
text_input_ids_no_mask = text_input_ids.clone()
if self.mlm:
text_input_ids, mlm_labels = mask_batch_text_tokens(
text_input_ids, self.tokenizer,
is_train=self.is_train) # make mlm data
else:
text_input_ids, mlm_labels = text_input_ids, None
text_input_mask = batch_enc.attention_mask # (B, L)
itm_labels = default_collate(
[d["itm_label"] for d in text_examples]) # (B, )
erase_elems = [random_erase(e, patch_size=self.patch_size) for e in visual_inputs.clone()]
if self.mpm:
crop_visual_inputs = v_collate([elems[0] for elems in erase_elems])
mpm_masks = v_collate([elems[1] for elems in erase_elems])
context_visual_inputs = v_collate([elems[2] for elems in erase_elems])
return dict(
visual_inputs=visual_inputs, # (B, #frm=1 or T, H, W, C)
crop_visual_inputs=crop_visual_inputs,
context_visual_inputs=context_visual_inputs,
mpm_mask=mpm_masks,
text_input_ids=text_input_ids_no_mask,
mlm_text_input_ids=text_input_ids,
mlm_labels=mlm_labels,
text_input_mask=text_input_mask, # used to exclude [PAD] token
itm_labels=itm_labels,
n_examples_list=n_examples_list, # used to create image feature copies.
type=batch[0]['type']
)
else:
return dict(
visual_inputs=visual_inputs, # (B, #frm=1 or T, H, W, C)
text_input_ids=text_input_ids_no_mask,
mlm_text_input_ids=text_input_ids,
mlm_labels=mlm_labels,
text_input_mask=text_input_mask, # used to exclude [PAD] token
itm_labels=itm_labels,
n_examples_list=n_examples_list, # used to create image feature copies.
type=batch[0]['type']
)
def random_erase(input_img, patch_size, s_l=0.3, s_h=0.5, r_1=0.3, r_2=1/0.3, v_l=0, v_h=255):
assert input_img.ndim == 4
img_t, img_c, img_h, img_w = input_img.shape
while True:
s = np.random.uniform(s_l, s_h) * img_h * img_w
r = np.random.uniform(r_1, r_2)
w = int(np.sqrt(s / r))
h = int(np.sqrt(s * r))
left = np.random.randint(0, img_w)
top = np.random.randint(0, img_h)
w = w - w % patch_size
h = h - h % patch_size
left = left - left % patch_size
top = top - top % patch_size
if left + w <= img_w and top + h <= img_h:
break
context_img = input_img.clone()
context_img[:, :, top: top + h, left: left + w] = 0
input_img = input_img[:, :, top: top + h, left: left + w]
pad = (left, img_w - left - w, top, img_h - top - h)
input_img = torch.nn.functional.pad(input_img, pad=pad, mode='constant', value=0.0)
img_masks = torch.ones_like(input_img)
img_masks[:, :, top: top+h, left: left+w] = 0
img_masks = torch.nn.functional.avg_pool2d(img_masks.float(), kernel_size=(patch_size, patch_size), stride=patch_size)
img_masks = torch.mean(img_masks, dim=(0, 1))
return input_img, img_masks, context_img
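# A minimal usage sketch of random_erase for one clip; the clip shape and patch
# size below are illustrative only.
def _example_random_erase():
    clip = torch.rand(4, 3, 224, 224) * 255             # (T, C, H, W)
    crop, mpm_mask, context = random_erase(clip, patch_size=16)
    # crop:     only the erased window kept, zero-padded back to the clip size
    # mpm_mask: (14, 14) patch-level map, 0 over patches inside the erased window
    # context:  the clip with the erased window zeroed out
    assert crop.shape == clip.shape and mpm_mask.shape == (14, 14)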
================================================
FILE: src/datasets/dataset_video_qa.py
================================================
import os
import torch
import random
import numpy as np
import copy
from torch.utils.data.dataloader import default_collate
from src.utils.basic_utils import flat_list_of_lists
from src.utils.load_save import LOGGER
from src.datasets.dataset_base import AlproBaseDataset
from src.datasets.randaugment import TemporalConsistentRandomAugment
class AlproVideoQADataset(AlproBaseDataset):
""" This should work for both train and test (where labels are not available).
task_type: str, one of [action, frameqa, transition]
where action and transition are multiple-choice QA,
frameqa is open-ended QA similar to VQA.
datalist: list(tuples) each tuple is (img_id, list(dicts)),
each dict
tokenizer:
max_img_size: int,
max_txt_len: int, max text sequence length, including special tokens.
return_label: bool, whether return label in __getitem__
random_sample_clips:
"""
open_ended_qa_names = ["frameqa", "msrvtt_qa", "msvd_qa"]
def __init__(self, task_type, datalist, tokenizer, img_lmdb_dir,
fps=3, num_frm=3, frm_sampling_strategy="rand",
max_img_size=1000, max_txt_len=20, ans2label=None,
ensemble_n_clips=1, return_label=True, is_train=False, random_sample_clips=True,
video_fmt='.mp4', img_db_type='lmdb'):
super(AlproVideoQADataset, self).__init__(
datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type,
fps=fps, num_frm=num_frm,
frm_sampling_strategy=frm_sampling_strategy,
max_img_size=max_img_size, max_txt_len=max_txt_len)
self.ensemble_n_clips = ensemble_n_clips
self.return_label = return_label
self.is_train = is_train
self.task_type = task_type
self.ans2label = ans2label
self.num_labels = len(ans2label)
self.random_sample_clips = random_sample_clips
self.label2ans = {v: k for k, v in ans2label.items()}
self.qid2data = {d["question_id"]: d for group in datalist for d in group[1]}
self.video_fmt = video_fmt
if self.is_train:
self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])
else:
self.randaug = None
def __len__(self):
return len(self.datalist)
def __getitem__(self, index):
# skip error videos:
num_retries = 5
for _ in range(num_retries):
vid_id, examples = self.datalist[index] # one video with multiple examples
if self.ensemble_n_clips > 1:
raise NotImplementedError('Do not support multiple clips for now.')
else:
video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt)
vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
# Select a random video if the current video could not be accessed.
if vid_frm_array is None:
LOGGER.info(f"Failed to load examples with video: {vid_id}. "
f"Will randomly sample an example as a replacement.")
index = random.randint(0, len(self) - 1)
continue
if self.randaug:
vid_frm_array = self.randaug(vid_frm_array.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
examples = [self._get_single_example(e) for e in examples]
return dict(
vid=vid_frm_array,
examples=examples,
n_examples=len(examples) # used to create image feature copies.
)
else:
raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
def _get_single_example(self, data):
example = dict(
q_str=data["question"],
question_id=data["question_id"],
label=data["answer"]
)
if self.task_type in self.open_ended_qa_names:
if self.return_label:
example["label"] = self.ans2label[example["label"]]
if not self.return_label:
example["label"] = None
return example
def evaluate_qa(self, results):
"""
Args:
results: list(dict),
each dict is
{
"question_id": int,
"answer": int or float, either answer_idx (int)
}
Returns:
TGIF-QA score
"""
preds = []
gts = []
# for frameQA
answer_types = []
answer_type2idx = dict(
frameqa={"object": 0, "number": 1, "color": 2, "location": 3},
msrvtt_qa={k: idx for idx, k in enumerate(["what", "who", "how", "where", "when"])},
msvd_qa={k: idx for idx, k in enumerate(["what", "who", "how", "where", "when"])}
)
qid2pred_ans = {r["question_id"]: r["answer"] for r in results}
if self.task_type in self.open_ended_qa_names: # convert ans_idx, int --> str
qid2pred_ans = {k: self.label2ans[v] for k, v in qid2pred_ans.items()}
for qid, pred_ans in qid2pred_ans.items():
preds.append(pred_ans)
gt_data = self.qid2data[qid]
gt_ans = gt_data["answer"]
if self.task_type in self.open_ended_qa_names:
answer_types.append(answer_type2idx[self.task_type][gt_data["answer_type"]])
gts.append(gt_ans)
preds = np.array(preds)
gts = np.array(gts)
metrics = dict()
# preds and gts are array of strings
metrics["overall_acc"] = float(np.mean(preds == gts))
if self.task_type in self.open_ended_qa_names:
answer_types = np.array(answer_types)
ratios = dict()
for ans_type, ans_type_idx in answer_type2idx[self.task_type].items():
answer_type_mask = answer_types == ans_type_idx
answer_type_corrects = (
preds[answer_type_mask] == gts[answer_type_mask])
metrics[f"{ans_type}_acc"] = float(
np.mean(answer_type_corrects)) if len(answer_type_corrects) != 0 else 0
ratios[f"{ans_type}_ratio"] = [
1. * len(answer_type_corrects) / len(answer_types),
len(answer_type_corrects)]
metrics["ratios"] = ratios
return metrics
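# Illustrative note on evaluate_qa for open-ended QA: it takes one dict per question
# and returns overall and per answer-type accuracies. The entries and the `dataset`
# name below are assumptions for demonstration only:
#   results = [{"question_id": 101, "answer": 3}, {"question_id": 102, "answer": 7}]
#   metrics = dataset.evaluate_qa(results)
#   metrics["overall_acc"]   # fraction of predictions matching the ground truth
#   metrics["what_acc"]      # per answer-type accuracy (msrvtt_qa / msvd_qa types)
#   metrics["ratios"]        # per answer-type share of the evaluation set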
class VideoQACollator(object):
def __init__(self, tokenizer, max_length=20, task_type="action", n_options=5):
self.tokenizer = tokenizer
self.max_length = max_length
self.task_type = task_type
self.n_options = n_options
def collate_batch(self, batch):
v_collate = default_collate
visual_inputs = v_collate([d["vid"] for d in batch]) # (B, T, 3, H, W)
# group data
text_examples = flat_list_of_lists([d["examples"] for d in batch])
n_examples_list = [d["n_examples"] for d in batch] # (B, )
# group elements data
# directly concatenate question and option as a single seq.
if self.task_type in ["action", "transition"]:
text_str_list = flat_list_of_lists(
[[d["q_str"] + " " + d["options_str_list"][i] for i in range(self.n_options)]
for d in text_examples]
) # (B * n_options, )
else:
text_str_list = [d["q_str"] for d in text_examples] # (B, )
batch_enc = self.tokenizer.batch_encode_plus(
text_str_list,
max_length=self.max_length,
padding='max_length',
return_tensors="pt",
truncation=True
)
text_input_ids = batch_enc.input_ids # (B, L)
text_input_mask = batch_enc.attention_mask # (B, L)
labels = default_collate([int(d["label"]) for d in text_examples]) \
if text_examples[0]["label"] is not None else None # (B, #ans)
question_ids = [d["question_id"] for d in text_examples]
return dict(
visual_inputs=visual_inputs, # (B, #frm, H, W, C)
text_input_ids=text_input_ids,
text_input_mask=text_input_mask,
question_ids=question_ids,
labels=labels,
n_examples_list=n_examples_list # used to create image feature copies.
)
================================================
FILE: src/datasets/dataset_video_retrieval.py
================================================
import random
import copy
import os
import torch
import numpy as np
from torch.utils.data.dataloader import default_collate
from src.utils.basic_utils import flat_list_of_lists
from src.utils.load_save import LOGGER
from src.datasets.dataset_base import AlproBaseDataset
from src.datasets.randaugment import TemporalConsistentRandomAugment
class AlproVideoRetrievalDataset(AlproBaseDataset):
""" This should work for both train and test (where labels are not available).
datalist: list(tuples) each tuple is (img_id, list(dicts)),
each dict
tokenizer:
max_img_size: int,
max_txt_len: int, max text sequence length, including special tokens.
random_sample_clips: bool, whether to use N randomly sampled clips or always use N uniformly sampled clips
"""
def __init__(self, datalist, tokenizer, img_lmdb_dir,
fps=3, num_frm=3, frm_sampling_strategy="rand",
max_img_size=1000, max_txt_len=40, itm_neg_size=1,
ensemble_n_clips=1, random_sample_clips=True,
video_fmt='.mp4', img_db_type='lmdb', is_train=False):
super(AlproVideoRetrievalDataset, self).__init__(
datalist, tokenizer, img_lmdb_dir, img_db_type=img_db_type,
fps=fps, num_frm=num_frm,
frm_sampling_strategy=frm_sampling_strategy,
max_img_size=max_img_size, max_txt_len=max_txt_len)
self.ensemble_n_clips = ensemble_n_clips
self.num_labels = 2
self.itm_neg_size = itm_neg_size
self.random_sample_clips = random_sample_clips
self.id2data = {
d["id"]: d for group in datalist for d in group[1]}
self.is_train = is_train
self.video_fmt = video_fmt
if self.is_train:
self.randaug = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', 'HorizontalFlip'])
else:
self.randaug = None
def __len__(self):
return len(self.datalist)
def __getitem__(self, index):
# skip error videos:
num_retries = 5
for _ in range(num_retries):
vid_id, examples = self.datalist[index] # one video with multiple examples
if self.ensemble_n_clips > 1:
raise NotImplementedError('Do not support multiple clips for now.')
else:
video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt)
vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
# Select a random video if the current video could not be accessed.
if vid_frm_array is None:
LOGGER.info(f"Failed to load examples with video: {vid_id}. "
f"Will randomly sample an example as a replacement.")
index = random.randint(0, len(self) - 1)
continue
sampled_examples = []
for e in examples:
s = self._get_single_example(e, index)
if isinstance(s, dict):
sampled_examples.append(s)
else:
sampled_examples.extend(s)
return dict(
vid=vid_frm_array,
examples=sampled_examples,
n_examples=len(sampled_examples) # used to create image feature copies.
)
else:
raise RuntimeError(
f"Failed to fetch video after {num_retries} retries.")
def _get_single_example(self, data, index):
examples = []
text_str = data["txt"]
itm_label = 1 # positive pair
examples.append(dict(
text_str=text_str,
itm_label=itm_label
))
return examples
class VideoRetrievalCollator(object):
def __init__(self, tokenizer, max_length=40):
self.tokenizer = tokenizer
self.max_length = max_length
def collate_batch(self, batch):
# FIXME there is a chance that two captions associated with the same video are batched together. Might need to fix.
v_collate = default_collate
visual_inputs = v_collate([d["vid"] for d in batch]) # (B, T, 3, H, W)
# group data
text_examples = flat_list_of_lists([d["examples"] for d in batch])
n_examples_list = [d["n_examples"] for d in batch] # (B, )
# group elements data
# directly concatenate question and option as a single seq.
text_str_list = [d["text_str"] for d in text_examples] # (B, )
batch_enc = self.tokenizer.batch_encode_plus(
text_str_list,
max_length=self.max_length,
padding='max_length',
return_tensors="pt",
truncation=True
)
text_input_ids = batch_enc.input_ids # (B, L)
text_input_mask = batch_enc.attention_mask # (B, L)
if "itm_label" in text_examples[0]:
itm_labels = default_collate(
[d["itm_label"] for d in text_examples]) # (B, )
else:
itm_labels = None
if "id" in text_examples[0]:
caption_ids = [d["id"] for d in text_examples] # (B, )
else:
caption_ids = None
collated_batch = dict(
visual_inputs=visual_inputs, # (B, #frm, H, W, C)
text_input_ids=text_input_ids,
text_input_mask=text_input_mask,
caption_ids=caption_ids, # list(int), example ids,
labels=itm_labels,
n_examples_list=n_examples_list # used to create image feature copies.
)
if "vid_id" in batch[0] and len(batch) == 1:
collated_batch["vid_id"] = batch[0]["vid_id"]
return collated_batch
class AlproVideoRetrievalEvalDataset(AlproBaseDataset):
""" Sample by video/image, calculate scores between each video with all the text
and loop through all the videos. Each batch will only contain a single video,
but multiple text.
datalist: list(dict), each dict
tokenizer:
max_img_size: int,
max_txt_len: int, max text sequence length, including special tokens.
"""
def __init__(self, datalist, tokenizer, img_lmdb_dir,
fps=3, num_frm=3, frm_sampling_strategy="rand",
max_img_size=1000, max_txt_len=40, ensemble_n_clips=1,
video_fmt='.mp4', img_db_type='lmdb'):
self.ensemble_n_clips = ensemble_n_clips
super(AlproVideoRetrievalEvalDataset, self).__init__(
datalist, tokenizer, img_lmdb_dir,
fps=fps, num_frm=num_frm,
frm_sampling_strategy=frm_sampling_strategy,
max_img_size=max_img_size, max_txt_len=max_txt_len,
img_db_type=img_db_type)
# id is unique id per caption/example
for i, d in enumerate(self.datalist):
assert i == d["id"]
self.gt_cap_id2vid_id = {d["id"]: d["vid_id"] for d in datalist}
self.cap_id2data = {d["id"]: d for d in datalist}
self.batches, self.text_batch = self._prepare_batches_by_video()
self.id2data = {d["id"]: d for d in self.datalist}
self.video_fmt = video_fmt
def __len__(self):
return len(self.batches)
def __getitem__(self, index):
# skip error videos:
batch = dict()
batch["vid_id"] = self.batches[index]["vid_id"] # one video with multiple examples
batch["examples"] = self.text_batch["examples"]
batch["n_examples"] = self.text_batch["n_examples"]
batch["ids"] = self.text_batch["ids"]
if self.ensemble_n_clips > 1:
raise NotImplementedError('Do not support multiple clips for now.')
else:
# if self.is_train and self.random_sample_clips:
vid_id = batch["vid_id"]
video_path = os.path.join(self.img_db_dir, vid_id + self.video_fmt)
vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
batch["vid"] = vid_frm_array
return batch
def _prepare_batches_by_video(self):
"""create batches where each batch contains a single video with multiple text"""
text_list = []
for d in self.datalist:
text_list.append(dict(
text_str=d["txt"],
id=d["id"],
))
text_batch = dict(
vid_id=None,
examples=text_list,
n_examples=len(text_list),
ids=[d["id"] for d in text_list]
)
# make 1000 batches for 1000video x 1000text combinations.
# each batch contains 1video x 1000text
batches = []
for idx, d in enumerate(self.datalist):
#_batch = copy.deepcopy(text_batch)
_batch = dict()
_batch["vid_id"] = d["vid_id"]
batches.append(_batch)
return batches, text_batch
================================================
FILE: src/datasets/randaugment.py
================================================
import cv2
import numpy as np
import torch
## aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
offset = -low * scale
table = np.arange(n_bins) * scale + offset
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
'''
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0: return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degree, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def horizontal_flip_func(img):
'''
[dxli]
horizontally flip an image.
'''
out = cv2.flip(img, 1)
return out
def solarize_func(img, thresh=128):
'''
same output as PIL.ImageOps.solarize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
'''
same output as PIL.ImageEnhance.Color
'''
## implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = (
np.float32([
[0.886, -0.114, -0.114],
[-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]])
)
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(
el - mean) * factor + mean
for el in range(256)
]).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def brightness_func(img, factor):
'''
same output as PIL.ImageEnhance.Brightness
'''
table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
'''
The differences between this result and PIL are all on the 4 boundaries; the center
areas are the same.
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
# def cutout_func(img, pad_size, replace=(0, 0, 0)):
# replace = np.array(replace, dtype=np.uint8)
# H, W = img.shape[0], img.shape[1]
# rh, rw = np.random.random(2)
# pad_size = pad_size // 2
# ch, cw = int(rh * H), int(rw * W)
# x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
# y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
# out = img.copy()
# out[x1:x2, y1:y2, :] = replace
# return out
### level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
# if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
# if np.random.random() > 0.5: level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
# if np.random.random() < 0.5:
# level = -level
return (level, replace_value)
return level_to_args
func_dict = {
'Identity': identity_func,
# 'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
'ShearX': shear_x_func,
'TranslateX': translate_x_func,
'TranslateY': translate_y_func,
'Posterize': posterize_func,
'ShearY': shear_y_func,
'HorizontalFlip': horizontal_flip_func # [dxli]
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity': none_level_to_args,
# 'AutoContrast': none_level_to_args,
'Equalize': none_level_to_args,
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize': solarize_level_to_args(MAX_LEVEL),
'Color': enhance_level_to_args(MAX_LEVEL),
'Contrast': enhance_level_to_args(MAX_LEVEL),
'Brightness': enhance_level_to_args(MAX_LEVEL),
'Sharpness': enhance_level_to_args(MAX_LEVEL),
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
'TranslateX': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'TranslateY': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'Posterize': posterize_level_to_args(MAX_LEVEL),
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
'HorizontalFlip': none_level_to_args # [dxli]
}
class TemporalConsistentRandomAugment(object):
def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
self.N = N
self.M = M
self.p = p
self.tensor_in_tensor_out = tensor_in_tensor_out
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N, replace=False)
# return [(op, 0.5, self.M) for op in sampled_ops]
return [(op, self.M) for op in sampled_ops]
def __call__(self, frames):
assert frames.shape[-1] == 3, 'Expecting the last dimension to be 3-channel RGB, i.e. (b, h, w, c).'
if self.tensor_in_tensor_out:
frames = frames.numpy().astype(np.uint8)
num_frames = frames.shape[0]
ops = num_frames * [self.get_random_ops()]
apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
frames = torch.stack(list(map(self._aug, frames, ops, apply_or_not)), dim=0).float()
return frames
def _aug(self, img, ops, apply_or_not):
for i, (name, level) in enumerate(ops):
if not apply_or_not[i]:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return torch.from_numpy(img)
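# A minimal usage sketch of TemporalConsistentRandomAugment: the same N sampled ops
# are applied to every frame, so the augmentation is temporally consistent across
# the clip. The frame shape and op list below are illustrative only.
def _example_temporal_consistent_randaug():
    aug = TemporalConsistentRandomAugment(
        N=2, M=5, augs=['Identity', 'Brightness', 'HorizontalFlip'])
    frames = torch.randint(0, 256, (8, 224, 224, 3), dtype=torch.uint8)  # (T, H, W, C)
    out = aug(frames)     # float tensor, same shape, identical ops on every frame
    assert out.shape == frames.shape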
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
def save_frames_grid(img_array, out_path):
import torch
from torchvision.utils import make_grid
from PIL import Image
if len(img_array.shape) == 3:
img_array = img_array.unsqueeze(0)
elif len(img_array.shape) == 5:
b, t, c, h, w = img_array.shape
img_array = img_array.view(-1, c, h, w)
elif len(img_array.shape) == 4:
pass
else:
raise NotImplementedError('Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored.')
assert img_array.shape[1] == 3, "Expecting RGB input of shape (N, 3, H, W)."
grid = make_grid(img_array)
ndarr = grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
img = Image.fromarray(ndarr)
img.save(out_path)
def stack(data, dim=0):
shape = data[0].shape # need to handle empty list
shape = shape[:dim] + (len(data),) + shape[dim:]
x = torch.cat(data, dim=dim)
x = x.reshape(shape)
# need to handle case where dim=-1
# which is not handled here yet
# but can be done with transposition
return x
if __name__ == '__main__':
import decord, os
from decord import VideoReader
decord.bridge.set_bridge('torch')
root_dir = '/export/share/dongxuli/data/webvid2m/postprocess/downsampled_videos'
video_id = '1058234725.mp4'
video_path = os.path.join(root_dir, video_id)
vr = VideoReader(video_path)
frames = vr.get_batch([1, 3, 5, 7, 9])
frames = frames
# a = TemporalConsistentRandomAugment(N=2, M=5, augs=['Identity', 'Contrast', 'Equalize','Brightness','Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'])
a = TemporalConsistentRandomAugment(N=1, M=5, augs=['HorizontalFlip'])
print(frames[0].shape)
save_frames_grid(frames.permute(0, 3, 1, 2), 'before.jpg')
after_frames = a(frames)
print(after_frames.shape)
save_frames_grid(after_frames.permute(0, 3, 1, 2), 'after.jpg')
================================================
FILE: src/modeling/alpro_models.py
================================================
import copy
import numpy as np
import torch
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from einops import rearrange, reduce, repeat
from horovod import torch as hvd
from src.modeling.timesformer.vit import TimeSformer
from src.modeling.xbert import (BertEmbeddings, BertEncoder, BertForMaskedLM,
BertLMPredictionHead, BertModel, BertPooler,
BertPreTrainedModel, BertPreTrainingHeads)
from src.utils.basic_utils import load_json, load_jsonl, save_frames_grid
from src.utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
class AlproBaseModel(nn.Module):
def __init__(self, config=None, input_format='RGB', video_enc_cfg=None, temp=0.07):
super().__init__()
self.temp = nn.Parameter(torch.ones([]) * temp)
self.bert_config = config
visual_model_cls = eval(video_enc_cfg['cls'])
self.visual_encoder = visual_model_cls(model_cfg=video_enc_cfg, input_format=input_format, cross_attention_config=config)
self.text_encoder = BertForMaskedLM.from_pretrained('bert-base-uncased', config=self.bert_config)
# FIXME make them configurable
embed_dim = 256
vision_width = 768
text_width = self.bert_config.hidden_size
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
self.itc_token_type = self.bert_config.itc_token_type
self.itm_head = nn.Linear(text_width, 2)
def load_separate_ckpt(self, visual_weights_path=None, bert_weights_path=None):
if visual_weights_path:
self.visual_encoder.load_state_dict(visual_weights_path)
# if bert_weights_path:
# load_multimodal_encoder_state_dict_with_mismatch(self.cross_encoder, bert_weights_path)
# load_mlm_head_state_dict_with_mismatch(self.mlm_head, bert_weights_path)
# def freeze_cnn_backbone(self):
# for n, p in self.visual_encoder.feature.named_parameters():
# p.requires_grad = False
class AlproForPretrain(AlproBaseModel):
def __init__(self, config, video_enc_cfg, input_format='RGB'):
super(AlproForPretrain, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)
# model for generating pseudo labels
self.prompter = Prompter(config, video_enc_cfg)
self.use_mask_prob = 0
self.mpm_head = nn.Sequential(
nn.Linear(config.hidden_size,
config.hidden_size * 2),
nn.ReLU(True),
nn.Linear(config.hidden_size * 2, self.prompter.entity_num)
)
def build_text_prompts(self, prompts):
self.prompter.build_text_prompts(prompts)
def get_pseudo_labels(self, batch):
return self.prompter.get_pseudo_labels(batch)
def forward(self, batch):
with torch.no_grad():
self.temp.clamp_(0.001,0.5)
visual_inputs = batch['visual_inputs']
use_mpm = 'mpm_mask' in batch
if use_mpm:
context_visual_inputs = batch['context_visual_inputs']
device = visual_inputs.device
b, t, c, h, w = visual_inputs.shape
# forward image and text features
# feats are normalized embeds
if use_mpm and np.random.uniform() < self.use_mask_prob:
video_embeds_total = self._forward_visual_embeds(torch.cat([visual_inputs, context_visual_inputs], dim=0))
# split for unmasked and masked
video_embeds, context_video_embeds = video_embeds_total[:b], video_embeds_total[b:]
else:
video_embeds = self._forward_visual_embeds(visual_inputs)
context_video_embeds = video_embeds
# we compute normalized feats for unmasked visual inputs only, used for ITC
video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)
video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)
# text embeddings and features
text_embeds, text_feat = self._forward_text_feats(batch)
# ========== (in-batch) ITC loss ==========
gathered_video_feats = hvd.allgather(video_feat)
gathered_text_feats = hvd.allgather(text_feat)
        assert self.itc_token_type == 'cls', 'Only CLS tokens are supported for ITC, found {}.'.format(self.itc_token_type)
sim_v2t = video_feat @ gathered_text_feats.t() / self.temp
sim_t2v = text_feat @ gathered_video_feats.t() / self.temp
        # [IMPORTANT] be very careful when initializing the GT sim_v2t
        # allgather returns the concatenated features in the order of local_rank()
sim_targets = torch.zeros_like(sim_v2t)
local_rank = hvd.local_rank()
b_start, b_end = b * local_rank, b * (local_rank + 1)
sim_targets[:, b_start: b_end] = torch.eye(b)
loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1)*sim_targets,dim=1).mean()
loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1)*sim_targets,dim=1).mean()
vtc_loss = (loss_v2t+loss_t2v) / 2
# ========= VTM ==========
text_atts = batch['text_input_mask']
# non-masked text and non-masked image
vtm_loss, vtm_logits, vtm_labels, encoder_outputs_pos = self.compute_vtm(text_embeds=text_embeds,
text_atts=text_atts,
video_embeds=video_embeds,
video_atts=video_atts,
sim_v2t=sim_v2t.clone(), # for hard mining
sim_t2v=sim_t2v.clone(), # for hard mining
return_encoder_out=True
)
# ========= MLM ==========
# masked text and non-masked image
if 'mlm_labels' in batch:
mlm_labels = batch['mlm_labels']
mlm_text_input_ids = batch['mlm_text_input_ids']
mlm_loss, mlm_logits, mlm_labels = self.compute_mlm(input_ids=mlm_text_input_ids,
text_input_mask=text_atts,
video_embeds=video_embeds,
video_atts=video_atts,
mlm_labels=mlm_labels
)
else:
mlm_logits = mlm_loss = mlm_labels = None
# ========= MPM ==========
if use_mpm:
mpm_labels, ignore_masks = self.get_pseudo_labels(batch)
mpm_loss, mpm_logits = self.compute_mpm_with_encoder_out(encoder_outputs=encoder_outputs_pos,
text_atts=text_atts,
soft_labels=mpm_labels,
ignore_masks=ignore_masks,
patch_masks=batch['mpm_mask']
)
else:
mpm_loss = mpm_logits = mpm_labels = None
return dict(
itc_loss=vtc_loss,
mlm_scores=mlm_logits, # (B, Lt, vocab_size), only text part
mlm_loss=mlm_loss, # (BxLt)
mlm_labels=mlm_labels, # (1, Lt), with -100 indicates ignored positions
itm_scores=vtm_logits, # (B, 2)
itm_loss=vtm_loss, # (1, )
itm_labels=vtm_labels, # (B, )
mpm_loss=mpm_loss,
mpm_logits=mpm_logits,
mpm_labels=mpm_labels
)
def _forward_visual_embeds(self, visual_inputs):
b, t, c, h, w = visual_inputs.shape
# timeSformer asks for (b, c, t, h, w) as input.
# image features
visual_inputs = visual_inputs.transpose(1, 2)
video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
return video_embeds
def _forward_text_feats(self, batch):
# text features
text_output = self.text_encoder.bert(batch['text_input_ids'],
attention_mask=batch['text_input_mask'],
return_dict = True,
mode = 'text'
)
text_embeds = text_output.last_hidden_state # b, Lt, fsz=768
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)
return text_embeds, text_feat
def compute_mpm_with_encoder_out(self, encoder_outputs, text_atts, soft_labels, ignore_masks, patch_masks):
txt_len = text_atts.shape[1]
# adding one to ignore visual cls tokens
visual_output = encoder_outputs.last_hidden_state[:, txt_len+1:]
bsz, h, w = patch_masks.shape
patch_masks_flatten_inverted = (1 - patch_masks.view(bsz, -1)).unsqueeze(-1)
# mean embeds of masked visual regions
num_masked_patches = torch.sum(patch_masks_flatten_inverted.squeeze(-1), dim=-1, keepdim=True)
masked_visual_embeds = patch_masks_flatten_inverted * visual_output
masked_visual_embeds = torch.sum(masked_visual_embeds, dim=1)
masked_visual_embeds /= num_masked_patches
# loss
mpm_logits = self.mpm_head(masked_visual_embeds)
cross_entropy = -torch.sum(F.log_softmax(mpm_logits, dim=1) * soft_labels, dim=1)
cross_entropy[ignore_masks] = 0.
mpm_loss = torch.sum(cross_entropy) / (bsz - torch.sum(ignore_masks))
return mpm_loss, mpm_logits
def compute_mpm(self, text_embeds, text_atts, image_embeds, image_atts, soft_labels, ignore_masks, patch_masks):
# forward cross-encoder
attention_mask = torch.cat([text_atts, image_atts], dim=1)
embedding_output = torch.cat([text_embeds, image_embeds], dim=1)
encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,
attention_mask=attention_mask,
return_dict=True,
mode='fusion'
)
txt_len = text_atts.shape[1]
# adding one to ignore visual cls tokens
visual_output = encoder_outputs.last_hidden_state[:, txt_len+1:]
bsz, h, w = patch_masks.shape
patch_masks_flatten_inverted = (1 - patch_masks.view(bsz, -1)).unsqueeze(-1)
# mean embeds of masked visual regions
num_masked_patches = torch.sum(patch_masks_flatten_inverted.squeeze(-1), dim=-1, keepdim=True)
masked_visual_embeds = patch_masks_flatten_inverted * visual_output
masked_visual_embeds = torch.sum(masked_visual_embeds, dim=1)
masked_visual_embeds /= num_masked_patches
# loss
mpm_logits = self.mpm_head(masked_visual_embeds)
cross_entropy = -torch.sum(F.log_softmax(mpm_logits, dim=1) * soft_labels, dim=1)
cross_entropy[ignore_masks] = 0.
mpm_loss = torch.sum(cross_entropy) / (bsz - torch.sum(ignore_masks))
return mpm_loss, mpm_logits
def compute_vtm(self, text_embeds, text_atts, video_embeds, video_atts, sim_v2t, sim_t2v, return_encoder_out=False):
device = text_embeds.device
# ====== positive pairs =======
attention_mask = torch.cat([text_atts, video_atts], dim=1)
embedding_output_pos = torch.cat([text_embeds, video_embeds], dim=1)
encoder_outputs_pos = self.text_encoder.bert(encoder_embeds=embedding_output_pos,
attention_mask=attention_mask,
return_dict=True,
mode='fusion'
)
# ====== negative pairs =======
bs = text_embeds.shape[0]
local_rank = hvd.local_rank()
b_start, b_end = bs * local_rank, bs * (local_rank + 1)
with torch.no_grad():
weights_i2t = sim_v2t[:,b_start:b_end]
weights_t2i = sim_t2v[:,b_start:b_end]
# never select self as negative
weights_i2t.fill_diagonal_(-np.Inf)
weights_t2i.fill_diagonal_(-np.Inf)
weights_i2t = F.softmax(weights_i2t, dim=1)
weights_t2i = F.softmax(weights_t2i, dim=1)
# select a negative image for each text
# FIXME to optimize using indexing operations
video_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2i[b], 1).item()
video_embeds_neg.append(video_embeds[neg_idx])
video_embeds_neg = torch.stack(video_embeds_neg,dim=0)
# select a negative text for each image
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_i2t[b], 1).item()
text_embeds_neg.append(text_embeds[neg_idx])
text_atts_neg.append(text_atts[neg_idx])
text_embeds_neg = torch.stack(text_embeds_neg,dim=0)
text_atts_neg = torch.stack(text_atts_neg,dim=0)
text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)
text_atts_all = torch.cat([text_atts, text_atts_neg],dim=0)
video_embeds_all = torch.cat([video_embeds_neg,video_embeds],dim=0)
video_atts_all = torch.cat([video_atts,video_atts],dim=0)
attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)
embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)
# forward negative pairs via cross encoder
encoder_outputs_neg = self.text_encoder.bert(encoder_embeds=embedding_output_all,
attention_mask=attention_mask_all,
return_dict=True,
mode='fusion'
)
vl_embeddings = torch.cat([encoder_outputs_pos.last_hidden_state[:,0,:],
encoder_outputs_neg.last_hidden_state[:,0,:]],dim=0)
vtm_logits = self.itm_head(vl_embeddings)
vtm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], dim=0).to(device)
vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)
if return_encoder_out:
return vtm_loss, vtm_logits, vtm_labels, encoder_outputs_pos
else:
return vtm_loss, vtm_logits, vtm_labels, None
def compute_mlm(self, input_ids, text_input_mask, video_embeds, video_atts, mlm_labels):
# forward text features with masked_input_ids
text_output = self.text_encoder.bert(input_ids,
attention_mask=text_input_mask,
return_dict=True,
mode='text'
)
text_embeds = text_output.last_hidden_state
# forward cross-encoder
attention_mask = torch.cat([text_input_mask, video_atts], dim=1)
embedding_output = torch.cat([text_embeds, video_embeds], dim=1)
encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,
attention_mask=attention_mask,
return_dict=True,
mode='fusion'
)
txt_len = text_input_mask.shape[1]
txt_output = encoder_outputs.last_hidden_state[:, :txt_len]
mlm_logits = self.text_encoder.cls(txt_output)
loss_fct = CrossEntropyLoss()
mlm_loss = loss_fct(mlm_logits.view(-1, self.bert_config.vocab_size), mlm_labels.view(-1))
return mlm_loss, mlm_logits, mlm_labels
def load_separate_ckpt(self, visual_weights_path=None, bert_weights_path=None, prompter_weights_path=None):
        if visual_weights_path:
            # load_state_dict expects a state dict, so read the checkpoint from disk first
            visual_state_dict = torch.load(visual_weights_path, map_location='cpu')
            self.visual_encoder.load_state_dict(visual_state_dict)
# [NOTE] BERT is initialized from huggingface pre-trained weights.
# if bert_weights_path:
# load_multimodal_encoder_state_dict_with_mismatch(self.cross_encoder, bert_weights_path)
# load_mlm_head_state_dict_with_mismatch(self.mlm_head, bert_weights_path)
# TODO make path configurable
if prompter_weights_path is not None:
self.prompter.load_pretrained_weights_without_prompts(prompter_weights_path)
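# Illustrative sketch (toy sizes, no horovod calls) of the in-batch VTC target
# construction used in AlproForPretrain.forward above: after hvd.allgather, the
# positives for the local worker form an identity block offset by
# local_rank * batch_size. The helper below is hypothetical and not referenced
# anywhere else in the codebase.
def _vtc_target_example(local_batch_size=4, num_workers=3, local_rank=1):
    gathered_size = local_batch_size * num_workers
    sim_targets = torch.zeros(local_batch_size, gathered_size)
    b_start, b_end = local_batch_size * local_rank, local_batch_size * (local_rank + 1)
    sim_targets[:, b_start:b_end] = torch.eye(local_batch_size)
    # row i is 1 only at column b_start + i, i.e. its own gathered counterpart
    return sim_targets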
class Prompter(AlproBaseModel):
def __init__(self, config, video_enc_cfg, input_format='RGB'):
super(Prompter, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)
# self.entity_num = 1000
self.entity_num = config.num_entities
self.register_buffer("video_prompt_feat", torch.rand(self.entity_num, 256))
self.register_buffer("image_prompt_feat", torch.rand(self.entity_num, 256))
self.prompt_initialized = False
# if the prob for the most likely entity is < 0.2, we just ignore it
self.ignore_threshold = 0.2
def load_pretrained_weights_without_prompts(self, ckpt_path):
LOGGER.info("Loading weights for teacher model.")
loaded_state_dict = torch.load(ckpt_path, map_location='cpu')
loaded_keys = loaded_state_dict.keys()
model_keys = self.state_dict().keys()
load_not_in_model = [k for k in loaded_keys if k not in model_keys]
model_not_in_load = [k for k in model_keys if k not in loaded_keys]
if hvd.rank() == 0:
LOGGER.info("Keys in loaded but not in model:")
LOGGER.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
LOGGER.info("Keys in model but not in loaded:")
LOGGER.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
# FIXME a quick hack to avoid loading prompts
new_loaded_state_dict = dict()
for k in loaded_state_dict:
if not 'prompt_feat' in k:
new_loaded_state_dict[k] = loaded_state_dict[k]
loaded_state_dict = new_loaded_state_dict
self.load_state_dict(loaded_state_dict, strict=False)
def build_text_prompts(self, prompts):
"""
This function will be called, if no e2e.weights is provided.
In that case,
"""
assert not self.prompt_initialized, "Repetitively building prompts?"
if self.training:
self.eval()
video_prompt_feat_all = []
image_prompt_feat_all = []
with torch.no_grad():
            # this is configurable depending on the GPU memory limit
step_size = 10000
# ====== initializing video prompting ======
b_video, _ = prompts['batch_enc_video_prompts'].input_ids.shape
start = 0
end = start + step_size
while start < b_video:
video_prompt_output = self.text_encoder.bert(prompts['batch_enc_video_prompts'].input_ids[start:end].cuda(),
attention_mask=prompts['batch_enc_video_prompts'].attention_mask[start:end].cuda(),
return_dict=True,
mode='text'
)
video_prompt_embeds = video_prompt_output.last_hidden_state # b, Lt, fsz=768
video_prompt_feat = F.normalize(self.text_proj(video_prompt_embeds[:,0,:]),dim=-1)
# collecting
video_prompt_feat_all.append(video_prompt_feat)
start += step_size
end += step_size
# average ensembling
video_prompt_feat = torch.cat(video_prompt_feat_all, dim=0)
video_num_templates = int(video_prompt_feat.shape[0] / self.entity_num)
video_prompt_feat = torch.stack(video_prompt_feat.chunk(video_num_templates), dim=1)
video_prompt_feat = torch.mean(video_prompt_feat, dim=1)
self.video_prompt_feat = video_prompt_feat
# ====== initializing image prompting ======
b_image, _ = prompts['batch_enc_image_prompts'].input_ids.shape
start = 0
end = start + step_size
while start < b_image:
# image prompts
image_prompt_output = self.text_encoder.bert(prompts['batch_enc_image_prompts'].input_ids[start:end].cuda(),
attention_mask=prompts['batch_enc_image_prompts'].attention_mask[start:end].cuda(),
return_dict = True,
mode = 'text'
)
image_prompt_embeds = image_prompt_output.last_hidden_state # b, Lt, fsz=768
image_prompt_feat = F.normalize(self.text_proj(image_prompt_embeds[:,0,:]),dim=-1)
# collecting
image_prompt_feat_all.append(image_prompt_feat)
start += step_size
end += step_size
image_prompt_feat = torch.cat(image_prompt_feat_all, dim=0)
image_num_templates = int(image_prompt_feat.shape[0] / self.entity_num)
image_prompt_feat = torch.stack(image_prompt_feat.chunk(image_num_templates), dim=1)
image_prompt_feat = torch.mean(image_prompt_feat, dim=1)
self.image_prompt_feat = image_prompt_feat
self.prompt_initialized = True
def _forward_visual_embeds(self, visual_inputs):
b, t, c, h, w = visual_inputs.shape
# timeSformer asks for (b, c, t, h, w) as input.
# image features
visual_inputs = visual_inputs.transpose(1, 2)
video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)
if self.itc_token_type == 'cls':
video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)
else:
raise NotImplementedError("itc_type_type must be one of ['mean', 'cls', 'mil'], found {}".format(self.itc_token_type))
return video_embeds, video_feat
    def _compute_soft_labels(self, sim_vp_masked):
        soft_labels = F.softmax(sim_vp_masked, dim=1)
        # ignore pseudo-labels whose most likely entity has a probability below the threshold
        ignore_masks = torch.max(soft_labels, dim=1)[0] < self.ignore_threshold
        return soft_labels, ignore_masks
def get_pseudo_labels(self, batch):
if self.training:
self.eval()
with torch.no_grad():
masked_visual_inputs = batch['crop_visual_inputs']
_, masked_image_feat = self._forward_visual_embeds(masked_visual_inputs)
if batch['type'] == 'video':
prompt_feat = self.video_prompt_feat
else:
prompt_feat = self.image_prompt_feat
# visual feat to video prompts
# masked visual feat to video prompts
sim_masked = masked_image_feat @ prompt_feat.t() / self.temp
pseudo_labels, ignore_masks = self._compute_soft_labels(sim_masked)
return pseudo_labels, ignore_masks
def forward(self, batch):
visual_inputs = batch['visual_inputs']
device = visual_inputs.device
b, t, c, h, w = visual_inputs.shape
# forward image and text features
# feats are normalized embeds
video_embeds, video_feat, text_embeds, text_feat = self.forward_feats(batch)
image_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)
# ========== (in-batch) ITC loss ==========
gathered_image_feats = hvd.allgather(video_feat)
gathered_text_feats = hvd.allgather(text_feat)
assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)
sim_v2t = video_feat @ gathered_text_feats.t() / self.temp
sim_t2v = text_feat @ gathered_image_feats.t() / self.temp
        # [IMPORTANT] be very careful when initializing the GT sim_v2t
        # allgather returns the concatenated features in the order of local_rank()
sim_targets = torch.zeros_like(sim_v2t)
local_rank = hvd.local_rank()
b_start, b_end = b * local_rank, b * (local_rank + 1)
sim_targets[:, b_start: b_end] = torch.eye(b)
sim_v2t_scores = F.log_softmax(sim_v2t, dim=1)
sim_t2v_scores = F.log_softmax(sim_t2v, dim=1)
loss_v2t = -torch.sum(sim_v2t_scores * sim_targets,dim=1).mean()
loss_t2v = -torch.sum(sim_t2v_scores * sim_targets,dim=1).mean()
vtc_loss = (loss_v2t+loss_t2v) / 2
return dict(
itc_loss=vtc_loss,
itc_labels=torch.max(sim_targets, dim=1)[1],
i2t_scores=sim_v2t_scores,
t2i_scores=sim_t2v_scores
)
def forward_feats(self, batch):
with torch.no_grad():
self.temp.clamp_(0.001,0.5)
visual_inputs = batch['visual_inputs']
b, t, c, h, w = visual_inputs.shape
# timeSformer asks for (b, c, t, h, w) as input.
# image features
visual_inputs = visual_inputs.transpose(1, 2)
video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
assert self.itc_token_type == 'cls', 'Expecting CLS token for ITC, found {}'.format(self.itc_token_type)
if self.itc_token_type == 'cls':
video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)
else:
raise NotImplementedError("itc_type_type must be one of ['mean', 'cls', 'mil'], found {}".format(self.itc_token_type))
# text features
text_output = self.text_encoder.bert(batch['text_input_ids'],
attention_mask=batch['text_input_mask'],
return_dict = True,
mode = 'text'
)
text_embeds = text_output.last_hidden_state # b, Lt, fsz=768
if self.itc_token_type == 'cls':
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)
else:
raise NotImplementedError("itc_token_type must be one of ['mean', 'cls', 'mil'], found {}".format(self.itc_token_type))
return video_embeds, video_feat, text_embeds, text_feat
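# Illustrative sketch (toy sizes) of the prompt "average ensembling" in
# Prompter.build_text_prompts above: the flattened prompt features are laid out
# template-by-template, so chunk i holds template i applied to every entity;
# stacking the chunks and averaging over dim=1 yields one feature per entity.
# The helper below is hypothetical and not used elsewhere.
def _prompt_ensembling_example(entity_num=5, num_templates=3, feat_dim=256):
    flat_feat = torch.randn(entity_num * num_templates, feat_dim)
    per_template = torch.stack(flat_feat.chunk(num_templates), dim=1)  # (entity_num, num_templates, feat_dim)
    return torch.mean(per_template, dim=1)                             # (entity_num, feat_dim)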
class AlproForSequenceClassification(AlproBaseModel):
def __init__(self, config, video_enc_cfg, input_format='RGB'):
super(AlproForSequenceClassification, self).__init__(config, video_enc_cfg=video_enc_cfg)
self.text_encoder = BertModel.from_pretrained('bert-base-uncased', config=self.bert_config, add_pooling_layer=False)
self.classifier = nn.Sequential(
nn.Linear(config.hidden_size,
config.hidden_size * 2),
nn.ReLU(True),
nn.Linear(config.hidden_size * 2, config.num_labels)
)
# def forward(self, image, text, targets, alpha=0, train=True):
def forward(self, batch):
visual_inputs = batch['visual_inputs']
targets = batch['labels']
device = visual_inputs.device
# forward text
text_input_mask = batch['text_input_mask']
text_output = self.text_encoder(batch['text_input_ids'],
attention_mask=text_input_mask,
return_dict=True,
mode='text'
)
text_embeds = text_output.last_hidden_state
# forward visual
b, t, c, h, w = visual_inputs.shape
# timeSformer asks for (b, c, t, h, w) as input.
visual_inputs = visual_inputs.transpose(1, 2)
image_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(device)
# forward cross-encoder
attention_mask = torch.cat([text_input_mask, image_atts], dim=1)
embedding_output = torch.cat([text_embeds, image_embeds], dim=1)
output = self.text_encoder(encoder_embeds=embedding_output,
attention_mask=attention_mask,
return_dict=True,
mode='fusion'
)
prediction = self.classifier(output.last_hidden_state[:,0,:])
if targets is not None:
loss = F.cross_entropy(prediction, targets)
else: # evaluation mode
loss = 0
return dict(loss=loss,
logits=prediction
)
def forward_inference(self, batch):
visual_inputs = batch['visual_inputs']
device = visual_inputs.device
# forward text
text_input_mask = batch['text_input_mask']
text_output = self.text_encoder.bert(batch['text_input_ids'],
attention_mask=text_input_mask,
return_dict=True,
mode='text'
)
text_embeds = text_output.last_hidden_state
# forward visual
b, t, c, h, w = visual_inputs.shape
# timeSformer asks for (b, c, t, h, w) as input.
visual_inputs = visual_inputs.transpose(1, 2)
image_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(device)
# forward cross-encoder
attention_mask = torch.cat([text_input_mask, image_atts], dim=1)
embedding_output = torch.cat([text_embeds, image_embeds], dim=1)
output = self.text_encoder.bert(encoder_embeds=embedding_output,
attention_mask=attention_mask,
return_dict=True,
mode='fusion'
)
prediction = self.classifier(output.last_hidden_state[:,0,:])
return prediction
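# Illustrative sketch (toy shapes) of the cross-modal fusion input used by the
# forward passes in this file: text token embeddings and visual token embeddings
# are concatenated along the sequence dimension, their attention masks likewise,
# and downstream heads read the first (text CLS) position of the fused output.
# The helper below is hypothetical and not used elsewhere.
def _fusion_input_example(bsz=2, text_len=20, num_visual_tokens=197, hidden=768):
    text_embeds = torch.zeros(bsz, text_len, hidden)
    video_embeds = torch.zeros(bsz, num_visual_tokens, hidden)
    text_mask = torch.ones(bsz, text_len, dtype=torch.long)
    video_mask = torch.ones(bsz, num_visual_tokens, dtype=torch.long)
    embedding_output = torch.cat([text_embeds, video_embeds], dim=1)  # (bsz, text_len + num_visual_tokens, hidden)
    attention_mask = torch.cat([text_mask, video_mask], dim=1)
    cls_embedding = embedding_output[:, 0, :]                         # position read by the classifier / itm_head
    return embedding_output.shape, attention_mask.shape, cls_embedding.shape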
class AlproForVideoTextRetrieval(AlproBaseModel):
"""
"""
def __init__(self, config, video_enc_cfg, input_format='RGB'):
super(AlproForVideoTextRetrieval, self).__init__(config, input_format=input_format, video_enc_cfg=video_enc_cfg)
def forward(self, batch):
with torch.no_grad():
self.temp.clamp_(0.001,0.5)
visual_inputs = batch['visual_inputs']
text_input_mask = batch['text_input_mask']
text_input_ids = batch['text_input_ids']
device = visual_inputs.device
b, t, c, h, w = visual_inputs.shape
# timeSformer asks for (b, c, t, h, w) as input.
# visual embeddings
visual_inputs = visual_inputs.transpose(1, 2)
video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
# image_embeds = image_embeds.repeat(text_input_mask.shape[0], 1, 1)
video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)
video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)
# text embeddings
text_output = self.text_encoder.bert(text_input_ids,
attention_mask=text_input_mask,
return_dict=True,
mode='text'
)
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)
# ========== (in-batch) ITC loss ==========
gathered_video_feats = hvd.allgather(video_feat)
gathered_text_feats = hvd.allgather(text_feat)
sim_v2t = video_feat @ gathered_text_feats.t() / self.temp
sim_t2v = text_feat @ gathered_video_feats.t() / self.temp
sim_targets = torch.zeros_like(sim_v2t)
local_rank = hvd.local_rank()
b_start, b_end = b * local_rank, b * (local_rank + 1)
sim_targets[:, b_start: b_end] = torch.eye(b)
loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1)*sim_targets,dim=1).mean()
loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1)*sim_targets,dim=1).mean()
vtc_loss = (loss_v2t+loss_t2v) / 2
# ========= ITM ==========
text_atts = batch['text_input_mask']
# non-masked text and non-masked image
vtm_loss, vtm_logits, vtm_labels = self.compute_vtm(text_embeds=text_embeds,
text_atts=text_atts,
image_embeds=video_embeds,
image_atts=video_atts,
sim_i2t=sim_v2t.clone(), # for hard mining
sim_t2i=sim_t2v.clone() # for hard mining
)
return dict(
itm_scores=vtm_logits,
itm_loss=vtm_loss,
itm_labels=vtm_labels,
itc_loss=vtc_loss
)
def compute_vtm(self, text_embeds, text_atts, image_embeds, image_atts, sim_i2t, sim_t2i):
device = text_embeds.device
# ====== positive pairs =======
attention_mask = torch.cat([text_atts, image_atts], dim=1)
embedding_output_pos = torch.cat([text_embeds, image_embeds], dim=1)
encoder_outputs_pos = self.text_encoder.bert(encoder_embeds=embedding_output_pos,
attention_mask=attention_mask,
return_dict=True,
mode='fusion'
)
# ====== negative pairs =======
bs = text_embeds.shape[0]
local_rank = hvd.local_rank()
b_start, b_end = bs * local_rank, bs * (local_rank + 1)
with torch.no_grad():
weights_v2t = sim_i2t[:,b_start:b_end]
weights_t2v = sim_t2i[:,b_start:b_end]
# never select self as negative
weights_v2t.fill_diagonal_(-np.Inf)
weights_t2v.fill_diagonal_(-np.Inf)
weights_v2t = F.softmax(weights_v2t, dim=1)
weights_t2v = F.softmax(weights_t2v, dim=1)
# select a negative image for each text
# FIXME to optimize using indexing operations
image_embeds_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_t2v[b], 1).item()
image_embeds_neg.append(image_embeds[neg_idx])
image_embeds_neg = torch.stack(image_embeds_neg,dim=0)
# select a negative text for each image
text_embeds_neg = []
text_atts_neg = []
for b in range(bs):
neg_idx = torch.multinomial(weights_v2t[b], 1).item()
text_embeds_neg.append(text_embeds[neg_idx])
text_atts_neg.append(text_atts[neg_idx])
text_embeds_neg = torch.stack(text_embeds_neg,dim=0)
text_atts_neg = torch.stack(text_atts_neg,dim=0)
text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0)
text_atts_all = torch.cat([text_atts, text_atts_neg],dim=0)
video_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0)
video_atts_all = torch.cat([image_atts,image_atts],dim=0)
attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1)
embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1)
# forward negative pairs via cross encoder
encoder_outputs_neg = self.text_encoder.bert(encoder_embeds=embedding_output_all,
attention_mask=attention_mask_all,
return_dict=True,
mode='fusion'
)
vl_embeddings = torch.cat([encoder_outputs_pos.last_hidden_state[:,0,:],
encoder_outputs_neg.last_hidden_state[:,0,:]],dim=0)
vtm_logits = self.itm_head(vl_embeddings)
vtm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], dim=0).to(device)
vtm_loss = F.cross_entropy(vtm_logits, vtm_labels)
return vtm_loss, vtm_logits, vtm_labels
def forward_inference(self, batch):
visual_inputs = batch['visual_inputs']
text_input_mask = batch['text_input_mask']
text_input_ids = batch['text_input_ids']
device = visual_inputs.device
b, t, c, h, w = visual_inputs.shape
# timeSformer asks for (b, c, t, h, w) as input.
visual_inputs = visual_inputs.transpose(1, 2)
video_embeds = self.visual_encoder.forward_features(visual_inputs, return_all_tokens=True)
video_feat = F.normalize(self.vision_proj(video_embeds[:,0,:]),dim=-1)
video_embeds = video_embeds.repeat(text_input_mask.shape[0], 1, 1)
# image_feat = image_feat.repeat(text_input_mask.shape[0], 1)
video_atts = torch.ones(video_embeds.size()[:-1],dtype=torch.long).to(device)
text_output = self.text_encoder.bert(text_input_ids,
attention_mask=text_input_mask,
return_dict=True,
mode='text'
)
text_embeds = text_output.last_hidden_state
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1)
vtc_sim_scores = video_feat @ text_feat.t() / self.temp
attention_mask = torch.cat([text_input_mask, video_atts], dim=1)
embedding_output = torch.cat([text_embeds, video_embeds], dim=1)
encoder_outputs = self.text_encoder.bert(encoder_embeds=embedding_output,
attention_mask=attention_mask,
return_dict=True,
mode='fusion'
)
vl_embeddings = encoder_outputs.last_hidden_state[:,0,:]
logits = self.itm_head(vl_embeddings)
return dict(logits=logits, itc_scores=vtc_sim_scores)
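# Illustrative sketch (toy values) of the hard-negative sampling in compute_vtm
# above: within the local block of the similarity matrix, the diagonal (the
# paired positive) is masked out, and one negative index per sample is drawn
# with probability proportional to softmax(similarity), so more similar
# negatives are selected more often. The helper below is hypothetical.
def _hard_negative_sampling_example(bs=4):
    sim = torch.randn(bs, bs)                # stand-in for sim_t2v[:, b_start:b_end]
    weights = sim.clone()
    weights.fill_diagonal_(-float('inf'))    # never select self as negative
    weights = F.softmax(weights, dim=1)
    neg_idx = torch.multinomial(weights, 1).squeeze(1)
    return neg_idx                           # one sampled negative index per row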
================================================
FILE: src/modeling/timesformer/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# from .build import MODEL_REGISTRY, build_model # noqa
# from .custom_video_model_builder import * # noqa
# from .video_model_builder import ResNet, SlowFast # noqa
================================================
FILE: src/modeling/timesformer/conv2d_same.py
================================================
# Copyright 2020 Ross Wightman
# Conv2d w/ Same Padding
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
import math
from typing import List, Tuple
#from .padding import pad_same, get_padding_value
# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
ih, iw = x.size()[-2:]
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
return x
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
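# get_padding_value below references get_padding and is_static_pad, whose import
# (commented out above) was dropped; the definitions here follow the
# corresponding helpers in timm's padding module.
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    # symmetric padding used for the static 'SAME' case
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    # 'SAME' padding can be applied statically only when it is symmetric
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0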
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = get_padding(kernel_size, **kwargs)
else:
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = get_padding(kernel_size, **kwargs)
return padding, dynamic
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
x = pad_same(x, weight.shape[-2:], stride, dilation)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
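# Illustrative numeric check (hypothetical helper, not used elsewhere) of the
# TF-style 'SAME' padding above: for a 224x224 input with a 7x7 kernel and
# stride 2, get_same_padding asks for 5 extra pixels per spatial dim, and
# conv2d_same keeps the output resolution at ceil(224 / 2) = 112.
def _same_padding_example():
    pad = get_same_padding(224, k=7, s=2, d=1)                   # -> 5
    weight = torch.zeros(8, 3, 7, 7)
    out = conv2d_same(torch.zeros(1, 3, 224, 224), weight, stride=(2, 2))
    return pad, out.shape                                        # (5, torch.Size([1, 8, 112, 112]))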
================================================
FILE: src/modeling/timesformer/features.py
================================================
# Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
class FeatureInfo:
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
prev_reduction = 1
for fi in feature_info:
# sanity check the mandatory fields, there may be additional fields depending on the model
assert 'num_chs' in fi and fi['num_chs'] > 0
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
prev_reduction = fi['reduction']
assert 'module' in fi
self.out_indices = out_indices
self.info = feature_info
def from_other(self, out_indices: Tuple[int]):
return FeatureInfo(deepcopy(self.info), out_indices)
def get(self, key, idx=None):
""" Get value by key at specified index (indices)
if idx == None, returns value for key at each output index
if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tuple, return value for each module index (ignoring output indices)
"""
if idx is None:
return [self.info[i][key] for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i][key] for i in idx]
else:
return self.info[idx][key]
def get_dicts(self, keys=None, idx=None):
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
"""
if idx is None:
if keys is None:
return [self.info[i] for i in self.out_indices]
else:
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
if isinstance(idx, (tuple, list)):
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
else:
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
def channels(self, idx=None):
""" feature channels accessor
"""
return self.get('num_chs', idx)
def reduction(self, idx=None):
""" feature reduction (output stride) accessor
"""
return self.get('reduction', idx)
def module_name(self, idx=None):
""" feature module name accessor
"""
return self.get('module', idx)
def __getitem__(self, item):
return self.info[item]
def __len__(self):
return len(self.info)
class FeatureHooks:
""" Feature Hook Helper
This module helps with the setup and extraction of hooks for extracting features from
internal nodes in a model by node name. This works quite well in eager Python but needs
    redesign for torchscript.
"""
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
# setup feature hooks
modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks):
hook_name = h['module']
m = modules[hook_name]
hook_id = out_map[i] if out_map else hook_name
hook_fn = partial(self._collect_output_hook, hook_id)
hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
if hook_type == 'forward_pre':
m.register_forward_pre_hook(hook_fn)
elif hook_type == 'forward':
m.register_forward_hook(hook_fn)
else:
assert False, "Unsupported hook type"
self._feature_outputs = defaultdict(OrderedDict)
def _collect_output_hook(self, hook_id, *args):
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
if isinstance(x, tuple):
x = x[0] # unwrap input tuple
self._feature_outputs[x.device][hook_id] = x
def get_output(self, device) -> Dict[str, torch.tensor]:
output = self._feature_outputs[device]
self._feature_outputs[device] = OrderedDict() # clear after reading
return output
def _module_list(module, flatten_sequential=False):
# a yield/iter would be better for this but wouldn't be compatible with torchscript
ml = []
for name, module in module.named_children():
if flatten_sequential and isinstance(module, nn.Sequential):
# first level of Sequential containers is flattened into containing model
for child_name, child_module in module.named_children():
combined = [name, child_name]
ml.append(('_'.join(combined), '.'.join(combined), child_module))
else:
ml.append((name, name, module))
return ml
def _get_feature_info(net, out_indices):
feature_info = getattr(net, 'feature_info')
if isinstance(feature_info, FeatureInfo):
return feature_info.from_other(out_indices)
elif isinstance(feature_info, (list, tuple)):
return FeatureInfo(net.feature_info, out_indices)
else:
assert False, "Provided feature_info is not valid"
def _get_return_layers(feature_info, out_map):
module_names = feature_info.module_name()
return_layers = {}
for i, name in enumerate(module_names):
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
return return_layers
class FeatureDictNet(nn.ModuleDict):
""" Feature extractor with OrderedDict return
Wrap a model and extract features as specified by the out indices, the network is
partially re-built from contained modules.
There is a strong assumption that the modules have been registered into the model in the same
order as they are used. There should be no reuse of the same nn.Module more than once, including
trivial modules like `self.relu = nn.ReLU`.
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
Arguments:
model (nn.Module): model from which we will extract the features
out_indices (tuple[int]): model output indices to extract features for
out_map (sequence): list or tuple specifying desired return id for each out index,
otherwise str(index) is used
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
"""
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices)
self.concat = feature_concat
self.return_layers = {}
return_layers = _get_return_layers(self.feature_info, out_map)
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = set(return_layers.keys())
layers = OrderedDict()
for new_name, old_name, module in modules:
layers[new_name] = module
if old_name in remaining:
# return id has to be consistently str type for torchscript
self.return_layers[new_name] = str(return_layers[old_name])
remaining.remove(old_name)
if not remaining:
break
assert not remaining and len(self.return_layers) == len(return_layers), \
f'Return layers ({remaining}) are not present in model'
self.update(layers)
def _collect(self, x) -> (Dict[str, torch.Tensor]):
out = OrderedDict()
for name, module in self.items():
x = module(x)
if name in self.return_layers:
out_id = self.return_layers[name]
if isinstance(x, (tuple, list)):
# If model tap is a tuple or list, concat or select first element
# FIXME this may need to be more generic / flexible for some nets
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
else:
out[out_id] = x
return out
def forward(self, x) -> Dict[str, torch.Tensor]:
return self._collect(x)
class FeatureListNet(FeatureDictNet):
""" Feature extractor with list return
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
"""
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
super(FeatureListNet, self).__init__(
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
flatten_sequential=flatten_sequential)
def forward(self, x) -> (List[torch.Tensor]):
return list(self._collect(x).values())
class FeatureHookNet(nn.ModuleDict):
""" FeatureHookNet
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
network in any way.
If `no_rewrite` is False, the model will be re-written as in the
FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
FIXME this does not currently work with Torchscript, see FeatureHooks class
"""
def __init__(
self, model,
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
super(FeatureHookNet, self).__init__()
assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict
layers = OrderedDict()
hooks = []
if no_rewrite:
assert not flatten_sequential
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
model.reset_classifier(0)
layers['body'] = model
hooks.extend(self.feature_info.get_dicts())
else:
modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info.get_dicts()}
for new_name, old_name, module in modules:
layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name):
if fn in remaining:
hooks.append(dict(module=fn, hook_type=remaining[fn]))
del remaining[fn]
if not remaining:
break
assert not remaining, f'Return layers ({remaining}) are not present in model'
self.update(layers)
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
def forward(self, x):
for name, module in self.items():
x = module(x)
out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values())
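# Illustrative usage sketch for the extractors above. The toy backbone below is
# hypothetical; it only exists to show the `feature_info` contract (num_chs,
# reduction, module) that _get_feature_info expects from a wrapped model.
class _ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, stride=2, padding=1)
        self.stage1 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
        self.stage2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.feature_info = [
            dict(num_chs=8, reduction=2, module='stem'),
            dict(num_chs=16, reduction=4, module='stage1'),
            dict(num_chs=32, reduction=8, module='stage2'),
        ]
def _feature_list_net_example():
    # collect the outputs of 'stage1' and 'stage2' for a dummy image
    extractor = FeatureListNet(_ToyBackbone(), out_indices=(1, 2))
    feats = extractor(torch.zeros(1, 3, 32, 32))
    return [f.shape for f in feats]   # [(1, 16, 8, 8), (1, 32, 4, 4)]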
================================================
FILE: src/modeling/timesformer/helpers.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified model creation / weight loading / state_dict helpers
import logging
import os
import sys
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
from src.modeling.timesformer.features import FeatureListNet, FeatureDictNet, FeatureHookNet
from src.modeling.timesformer.conv2d_same import Conv2dSame
from src.modeling.timesformer.linear import Linear
from horovod import torch as hvd
_logger = logging.getLogger()
def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = 'state_dict'
if isinstance(checkpoint, dict):
if use_ema and 'state_dict_ema' in checkpoint:
state_dict_key = 'state_dict_ema'
if state_dict_key and state_dict_key in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint[state_dict_key].items():
# strip `module.` prefix
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
state_dict = new_state_dict
elif 'model_state' in checkpoint:
state_dict_key = 'model_state'
new_state_dict = OrderedDict()
for k, v in checkpoint[state_dict_key].items():
# strip `model.` prefix
name = k[6:] if k.startswith('model') else k
new_state_dict[name] = v
state_dict = new_state_dict
else:
state_dict = checkpoint
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
return state_dict
else:
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict)
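# Illustrative sketch (hypothetical helper, toy dict) of the prefix stripping in
# load_state_dict above: checkpoints saved from DataParallel/DistributedDataParallel
# wrappers store parameters under 'module.<name>', which is removed before loading.
def _strip_module_prefix_example():
    ckpt_state = {'module.patch_embed.proj.weight': torch.zeros(1)}
    cleaned = OrderedDict()
    for k, v in ckpt_state.items():
        cleaned[k[7:] if k.startswith('module') else k] = v
    return list(cleaned.keys())   # ['patch_embed.proj.weight']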
# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
# resume_epoch = None
# if os.path.isfile(checkpoint_path):
# checkpoint = torch.load(checkpoint_path, map_location='cpu')
# if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
# if log_info:
# _logger.info('Restoring model state from checkpoint...')
# new_state_dict = OrderedDict()
# for k, v in checkpoint['state_dict'].items():
# name = k[7:] if k.startswith('module') else k
# new_state_dict[name] = v
# model.load_state_dict(new_state_dict)
# if optimizer is not None and 'optimizer' in checkpoint:
# if log_info:
# _logger.info('Restoring optimizer state from checkpoint...')
# optimizer.load_state_dict(checkpoint['optimizer'])
# if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
# if log_info:
# _logger.info('Restoring AMP loss scaler state from checkpoint...')
# loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
# if 'epoch' in checkpoint:
# resume_epoch = checkpoint['epoch']
# if 'version' in checkpoint and checkpoint['version'] > 1:
# resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
# if log_info:
# _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
# else:
# model.load_state_dict(checkpoint)
# if log_info:
# _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
# return resume_epoch
# else:
# _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
# raise FileNotFoundError()
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, img_size=224, num_frames=8, num_patches=196, attention_type='divided_space_time', pretrained_model="", strict=True):
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
_logger.warning("Pretrained model URL is invalid, using random initialization.")
return
if len(pretrained_model) == 0:
if cfg is None:
_logger.info(f"loading from default config {model.default_cfg}.")
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
else:
try:
state_dict = load_state_dict(pretrained_model)['model']
except:
state_dict = load_state_dict(pretrained_model)
if filter_fn is not None:
state_dict = filter_fn(state_dict)
if in_chans == 1:
conv1_name = cfg['first_conv']
_logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
conv1_weight = state_dict[conv1_name + '.weight']
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
O, I, J, K = conv1_weight.shape
if I > 3:
assert conv1_weight.shape[1] % 3 == 0
# For models with space2depth stems
conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
else:
conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
elif in_chans != 3:
conv1_name = cfg['first_conv']
conv1_weight = state_dict[conv1_name + '.weight']
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
O, I, J, K = conv1_weight.shape
if I != 3:
_logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
del state_dict[conv1_name + '.weight']
strict = False
else:
_logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
repeat = int(math.ceil(in_chans / 3))
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
conv1_weight *= (3 / float(in_chans))
conv1_weight = conv1_weight.to(conv1_type)
state_dict[conv1_name + '.weight'] = conv1_weight
classifier_name = cfg['classifier']
if num_classes == 1000 and cfg['num_classes'] == 1001:
# special case for imagenet trained models with extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[1:]
elif num_classes != state_dict[classifier_name + '.weight'].size(0):
#print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True)
# completely discard fully connected for all other differences between pretrained and created model
del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias']
strict = False
## Resizing the positional embeddings in case they don't match
_logger.info(f"Resizing spatial position embedding from {state_dict['pos_embed'].size(1)} to {num_patches + 1}")
if num_patches + 1 != state_dict['pos_embed'].size(1):
pos_embed = state_dict['pos_embed']
cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1)
other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2)
new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
new_pos_embed = new_pos_embed.transpose(1, 2)
new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
state_dict['pos_embed'] = new_pos_embed
## Resizing time embeddings in case they don't match
if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):
_logger.info(f"Resizing temporal position embedding from {state_dict['time_embed'].size(1)} to {num_frames}")
time_embed = state_dict['time_embed'].transpose(1, 2)
new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest')
state_dict['time_embed'] = new_time_embed.transpose(1, 2)
## Initializing temporal attention
if attention_type == 'divided_space_time':
new_state_dict = state_dict.copy()
for key in state_dict:
if 'blocks' in key and 'attn' in key:
new_key = key.replace('attn','temporal_attn')
if not new_key in state_dict:
new_state_dict[new_key] = state_dict[key]
else:
new_state_dict[new_key] = state_dict[new_key]
if 'blocks' in key and 'norm1' in key:
new_key = key.replace('norm1','temporal_norm1')
if not new_key in state_dict:
new_state_dict[new_key] = state_dict[key]
else:
new_state_dict[new_key] = state_dict[new_key]
state_dict = new_state_dict
## Loading the weights
model.load_state_dict(state_dict, strict=False)
def load_pretrained_CLIP_ViT(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):
if hvd.rank() == 0:
_logger.info(f"Loading CLIP ViT-B/16 checkpoints.")
loaded_state_dict = torch.load(pretrained_model)
## Initializing temporal attention
new_state_dict = loaded_state_dict.copy()
for key in loaded_state_dict:
if 'blocks' in key and 'attn' in key:
new_key = key.replace('attn','temporal_attn')
if not new_key in loaded_state_dict:
new_state_dict[new_key] = loaded_state_dict[key]
else:
new_state_dict[new_key] = loaded_state_dict[new_key]
if 'blocks' in key and 'norm1' in key:
new_key = key.replace('norm1','temporal_norm1')
if not new_key in loaded_state_dict:
new_state_dict[new_key] = loaded_state_dict[key]
else:
new_state_dict[new_key] = loaded_state_dict[new_key]
loaded_state_dict = new_state_dict
loaded_keys = loaded_state_dict.keys()
model_keys = model.state_dict().keys()
load_not_in_model = [k for k in loaded_keys if k not in model_keys]
model_not_in_load = [k for k in model_keys if k not in loaded_keys]
toload = dict()
mismatched_shape_keys = []
for k in model_keys:
if k in loaded_keys:
if model.state_dict()[k].shape != loaded_state_dict[k].shape:
mismatched_shape_keys.append(k)
else:
toload[k] = loaded_state_dict[k]
if hvd.rank() == 0:
_logger.info("Keys in loaded but not in model:")
_logger.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
_logger.info("Keys in model but not in loaded:")
_logger.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
_logger.info("Keys in model and loaded, but shape mismatched:")
_logger.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}")
model.load_state_dict(toload, strict=False)
def load_pretrained_imagenet(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):
import timm
if hvd.rank() == 0:
_logger.info(f"Loading vit_base_patch16_224 checkpoints.")
loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224(pretrained=True).state_dict()
del loaded_state_dict['head.weight']
del loaded_state_dict['head.bias']
## Initializing temporal attention
new_state_dict = loaded_state_dict.copy()
for key in loaded_state_dict:
if 'blocks' in key and 'attn' in key:
new_key = key.replace('attn','temporal_attn')
if not new_key in loaded_state_dict:
new_state_dict[new_key] = loaded_state_dict[key]
else:
new_state_dict[new_key] = loaded_state_dict[new_key]
if 'blocks' in key and 'norm1' in key:
new_key = key.replace('norm1','temporal_norm1')
if not new_key in loaded_state_dict:
new_state_dict[new_key] = loaded_state_dict[key]
else:
new_state_dict[new_key] = loaded_state_dict[new_key]
loaded_state_dict = new_state_dict
loaded_keys = loaded_state_dict.keys()
model_keys = model.state_dict().keys()
load_not_in_model = [k for k in loaded_keys if k not in model_keys]
model_not_in_load = [k for k in model_keys if k not in loaded_keys]
toload = dict()
mismatched_shape_keys = []
for k in model_keys:
if k in loaded_keys:
if model.state_dict()[k].shape != loaded_state_dict[k].shape:
mismatched_shape_keys.append(k)
else:
toload[k] = loaded_state_dict[k]
if hvd.rank() == 0:
_logger.info("Keys in loaded but not in model:")
_logger.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}")
_logger.info("Keys in model but not in loaded:")
_logger.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}")
_logger.info("Keys in model and loaded, but shape mismatched:")
_logger.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}")
model.load_state_dict(toload, strict=False)
def load_pretrained_kinetics(model, pretrained_model, cfg=None, ignore_classifier=True, num_frames=8, num_patches=196, **kwargs):
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
_logger.warning("Pretrained model URL is invalid, using random initialization.")
return
assert len(pretrained_model) > 0, "Path to pre-trained Kinetics weights not provided."
state_dict = load_state_dict(pretrained_model)
classifier_name = cfg['classifier']
if ignore_classifier:
classifier_weight_key = classifier_name + '.weight'
classifier_bias_key = classifier_name + '.bias'
state_dict[classifier_weight_key] = model.state_dict()[classifier_weight_key]
state_dict[classifier_bias_key] = model.state_dict()[classifier_bias_key]
else:
raise NotImplementedError('[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier.')
## Resizing the positional embeddings in case they don't match
if num_patches + 1 != state_dict['pos_embed'].size(1):
new_pos_embed = resize_spatial_embedding(state_dict, 'pos_embed', num_patches)
state_dict['pos_embed'] = new_pos_embed
## Resizing time embeddings in case they don't match
if 'time_embed' in state_dict and num_frames != state_dict['time_embed'].size(1):
state_dict['time_embed'] = resize_temporal_embedding(state_dict, 'time_embed', num_frames)
## Loading the weights
try:
model.load_state_dict(state_dict, strict=True)
_logger.info('Succeeded in loading Kinetics pre-trained weights.')
except:
_logger.error('Error in loading Kinetics pre-trained weights.')
def resize_spatial_embedding(state_dict, key, num_patches):
_logger.info(f"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}")
pos_embed = state_dict[key]
cls_pos_embed = pos_embed[0,0,:].unsqueeze(0).unsqueeze(1)
other_pos_embed = pos_embed[0,1:,:].unsqueeze(0).transpose(1, 2)
new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode='nearest')
new_pos_embed = new_pos_embed.transpose(1, 2)
new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
return new_pos_embed
def resize_temporal_embedding(state_dict, key, num_frames):
_logger.info(f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}")
time_embed = state_dict[key].transpose(1, 2)
new_time_embed = F.interpolate(time_embed, size=(num_frames), mode='nearest')
return new_time_embed.transpose(1, 2)
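# Illustrative sketch (hypothetical helper, dummy tensors) of the resizing
# functions above: a ViT-B/16 spatial embedding with 196 patches + CLS is
# stretched to 784 patches + CLS, and an 8-frame temporal embedding to 16 frames.
def _resize_embedding_example():
    dummy = {'pos_embed': torch.zeros(1, 197, 768),
             'time_embed': torch.zeros(1, 8, 768)}
    new_pos = resize_spatial_embedding(dummy, 'pos_embed', num_patches=784)
    new_time = resize_temporal_embedding(dummy, 'time_embed', num_frames=16)
    return new_pos.shape, new_time.shape   # (1, 785, 768), (1, 16, 768)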
================================================
FILE: src/modeling/timesformer/linear.py
================================================
""" Linear layer (alternate definition)
"""
import torch
import torch.nn.functional as F
from torch import nn as nn
class Linear(nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting():
bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
else:
return F.linear(input, self.weight, self.bias)
================================================
FILE: src/modeling/timesformer/operators.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# """Custom operators."""
# import torch
# import torch.nn as nn
# class Swish(nn.Module):
# """Swish activation function: x * sigmoid(x)."""
# def __init__(self):
# super(Swish, self).__init__()
# def forward(self, x):
# return SwishEfficient.apply(x)
# class SwishEfficient(torch.autograd.Function):
# """Swish activation function: x * sigmoid(x)."""
# @staticmethod
# def forward(ctx, x):
# result = x * torch.sigmoid(x)
# ctx.save_for_backward(x)
# return result
# @staticmethod
# def backward(ctx, grad_output):
# x = ctx.saved_variables[0]
# sigmoid_x = torch.sigmoid(x)
# return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
# class SE(nn.Module):
# """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""
# def _round_width(self, width, multiplier, min_width=8, divisor=8):
# """
# Round width of filters based on width multiplier
# Args:
# width (int): the channel dimensions of the input.
# multiplier (float): the multiplication factor.
# min_width (int): the minimum width after multiplication.
# divisor (int): the new width should be dividable by divisor.
# """
# if not multiplier:
# return width
# width *= multiplier
# min_width = min_width or divisor
# width_out = max(
# min_width, int(width + divisor / 2) // divisor * divisor
# )
# if width_out < 0.9 * width:
# width_out += divisor
# return int(width_out)
# def __init__(self, dim_in, ratio, relu_act=True):
# """
# Args:
# dim_in (int): the channel dimensions of the input.
# ratio (float): the channel reduction ratio for squeeze.
# relu_act (bool): whether to use ReLU activation instead
# of Swish (default).
# divisor (int): the new width should be dividable by divisor.
# """
# super(SE, self).__init__()
# self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
# dim_fc = self._round_width(dim_in, ratio)
# self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True)
# self.fc1_act = nn.ReLU() if relu_act else Swish()
# self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True)
# self.fc2_sig = nn.Sigmoid()
# def forward(self, x):
# x_in = x
# for module in self.children():
# x = module(x)
# return x_in * x
================================================
FILE: src/modeling/timesformer/vit.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified Model definition
import torch
import torch.nn as nn
from functools import partial
import math
import warnings
import torch.nn.functional as F
import numpy as np
import torch.utils
import torch.utils.checkpoint
from src.modeling.timesformer.vit_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from src.modeling.timesformer.helpers import load_pretrained, load_pretrained_kinetics, load_pretrained_imagenet, load_pretrained_CLIP_ViT
from src.modeling.timesformer.vit_utils import DropPath, to_2tuple, trunc_normal_
from src.modeling.xbert import BertAttention
# from .build import MODEL_REGISTRY
from torch import einsum
from einops import rearrange, reduce, repeat
import src.utils.grad_ckpt as grad_ckpt
from src.utils.logger import LOGGER, TB_LOGGER, add_log_to_file, RunningMeter
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'vit_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
}
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., with_qkv=True):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.with_qkv = with_qkv
if self.with_qkv:
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_drop = nn.Dropout(attn_drop)
def forward(self, x):
B, N, C = x.shape
if self.with_qkv:
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
else:
qkv = x.reshape(B, N, self.num_heads, C //
self.num_heads).permute(0, 2, 1, 3)
q, k, v = qkv, qkv, qkv
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
if self.with_qkv:
x = self.proj(x)
x = self.proj_drop(x)
return x
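# Shape walkthrough for the Attention module above (illustrative sketch; the sizes
# are arbitrary assumptions, e.g. a 14x14 patch grid plus one CLS token):
#
#   attn = Attention(dim=768, num_heads=12, qkv_bias=True)
#   x = torch.randn(2, 197, 768)        # (B, N, C)
#   out = attn(x)                       # torch.Size([2, 197, 768])
#   # qkv: (2, 197, 2304) -> (3, 2, 12, 197, 64) after reshape/permute;
#   # attention scores: (2, 12, 197, 197); output is projected back to (B, N, C).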
class Block(nn.Module):
def __init__(self, dim, num_heads, layer_num, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_type='divided_space_time',
use_grad_checkpointing=False):
super().__init__()
self.attention_type = attention_type
assert(attention_type in ['divided_space_time',
'space_only', 'joint_space_time'])
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# Temporal Attention Parameters
if self.attention_type == 'divided_space_time':
self.temporal_norm1 = norm_layer(dim)
self.temporal_attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.temporal_fc = nn.Linear(dim, dim)
# drop path
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# [dxli]
self.layer_num = layer_num
self.use_grad_checkpointing = use_grad_checkpointing
def forward(self, x, B, T, W):
num_spatial_tokens = (x.size(1) - 1) // T
H = num_spatial_tokens // W
if self.attention_type in ['space_only', 'joint_space_time']:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
elif self.attention_type == 'divided_space_time':
# Temporal
xt = x[:, 1:, :]
xt = rearrange(xt, 'b (h w t) m -> (b h w) t m',
b=B, h=H, w=W, t=T)
if self.use_grad_checkpointing:
# temporal_attn_out = torch.utils.checkpoint.checkpoint(self.temporal_attn, self.temporal_norm1(xt))
temporal_attn_out = grad_ckpt.CheckpointFunction.apply(self.temporal_attn, 1, self.temporal_norm1(xt))
else:
temporal_attn_out = self.temporal_attn(self.temporal_norm1(xt))
# res_temporal = self.drop_path(
# self.temporal_attn(self.temporal_norm1(xt)))
res_temporal = self.drop_path(temporal_attn_out)
res_temporal = rearrange(
res_temporal, '(b h w) t m -> b (h w t) m', b=B, h=H, w=W, t=T)
res_temporal = self.temporal_fc(res_temporal)
xt = x[:, 1:, :] + res_temporal
# Spatial
init_cls_token = x[:, 0, :].unsqueeze(1)
cls_token = init_cls_token.repeat(1, T, 1)
cls_token = rearrange(
cls_token, 'b t m -> (b t) m', b=B, t=T).unsqueeze(1)
xs = xt
xs = rearrange(xs, 'b (h w t) m -> (b t) (h w) m',
b=B, h=H, w=W, t=T)
xs = torch.cat((cls_token, xs), 1)
# [original]
# res_spatial = self.drop_path(self.attn(self.norm1(xs)))
if self.use_grad_checkpointing:
spatial_attn_out = grad_ckpt.CheckpointFunction.apply(self.attn, 1, self.norm1(xs))
else:
# spatial_attn_out = torch.utils.checkpoint.checkpoint(self.attn, self.norm1(xs))
spatial_attn_out = self.attn(self.norm1(xs))
res_spatial = self.drop_path(spatial_attn_out)
# Taking care of CLS token
cls_token = res_spatial[:, 0, :]
cls_token = rearrange(cls_token, '(b t) m -> b t m', b=B, t=T)
# averaging for every frame
cls_token = torch.mean(cls_token, 1, True)
res_spatial = res_spatial[:, 1:, :]
res_spatial = rearrange(
res_spatial, '(b t) (h w) m -> b (h w t) m', b=B, h=H, w=W, t=T)
res = res_spatial
x = xt
# Mlp
x = torch.cat((init_cls_token, x), 1) + \
torch.cat((cls_token, res), 1)
x_res = x
x = self.norm2(x)
# x = x + self.drop_path(self.mlp(self.norm2(x)))
# MLP
# [original]
# x = x_res + self.drop_path(self.mlp(x))
if self.use_grad_checkpointing:
# mlp_out = torch.utils.checkpoint.checkpoint(self.mlp, x)
mlp_out = grad_ckpt.CheckpointFunction.apply(self.mlp, 1, x)
else:
mlp_out = self.mlp(x)
x = x_res + self.drop_path(mlp_out)
return x
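# Token bookkeeping for divided space-time attention (illustrative sketch; the
# sizes below are assumptions for a 224x224 input with 16x16 patches and 8 frames):
#
#   B, T, H, W, dim = 2, 8, 14, 14, 768
#   blk = Block(dim=dim, num_heads=12, layer_num=0)
#   x = torch.randn(B, 1 + H * W * T, dim)   # one shared CLS token + H*W*T patch tokens
#   y = blk(x, B, T, W)                      # same shape: (B, 1 + H*W*T, dim)
#   # temporal attention runs on (B*H*W, T, dim), spatial attention on (B*T, 1 + H*W, dim);
#   # the per-frame CLS outputs are mean-pooled back into the single shared CLS token.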
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * \
(img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, T, H, W = x.shape
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.proj(x)
W = x.size(-1)
x = x.flatten(2).transpose(1, 2)
return x, T, W
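# Shape sketch for the patch embedding above (illustrative; values are assumptions):
#
#   pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
#   video = torch.randn(2, 3, 8, 224, 224)   # (B, C, T, H, W)
#   tokens, T, W = pe(video)
#   # tokens: (B*T, 196, 768) = (16, 196, 768); T = 8; W = 14 (feature-map width)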
class VisionTransformer(nn.Module):
""" Vision Transformere
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0.1, hybrid_backbone=None, norm_layer=nn.LayerNorm, num_frames=8, attention_type='divided_space_time', dropout=0.,
cross_attention_config=None, use_grad_checkpointing=False):
super().__init__()
self.attention_type = attention_type
self.depth = depth
self.dropout = nn.Dropout(dropout)
self.num_classes = num_classes
# num_features for consistency with other models
self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
# Positional Embeddings
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
if self.attention_type != 'space_only':
self.time_embed = nn.Parameter(
torch.zeros(1, num_frames, embed_dim))
self.time_drop = nn.Dropout(p=drop_rate)
# Attention Blocks
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
self.depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(layer_num=i, use_grad_checkpointing=use_grad_checkpointing,
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, attention_type=self.attention_type)
for i in range(self.depth)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
# initialization of temporal attention weights
if self.attention_type == 'divided_space_time':
i = 0
for m in self.blocks.modules():
m_str = str(m)
if 'Block' in m_str:
if i > 0:
nn.init.constant_(m.temporal_fc.weight, 0)
nn.init.constant_(m.temporal_fc.bias, 0)
i += 1
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'time_embed'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x, return_all_tokens=False):
B = x.shape[0]
x, T, W = self.patch_embed(x)
cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# resizing the positional embeddings in case they don't match the input at inference
if x.size(1) != self.pos_embed.size(1):
pos_embed = self.pos_embed
cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1)
other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2)
P = int(other_pos_embed.size(2) ** 0.5)
H = x.size(1) // W
other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P)
new_pos_embed = F.interpolate(
other_pos_embed, size=(H, W), mode='nearest')
new_pos_embed = new_pos_embed.flatten(2)
new_pos_embed = new_pos_embed.transpose(1, 2)
new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1)
x = x + new_pos_embed
else:
x = x + self.pos_embed
x = self.pos_drop(x)
# Time Embeddings
if self.attention_type != 'space_only':
cls_tokens = x[:B, 0, :].unsqueeze(1)
x = x[:, 1:]
x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
# Resizing time embeddings in case they don't match
if T != self.time_embed.size(1):
time_embed = self.time_embed.transpose(1, 2)
new_time_embed = F.interpolate(
time_embed, size=(T), mode='nearest')
new_time_embed = new_time_embed.transpose(1, 2)
x = x + new_time_embed
else:
x = x + self.time_embed
x = self.time_drop(x)
x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)
x = torch.cat((cls_tokens, x), dim=1)
# Attention blocks
for blk in self.blocks:
x = blk(x, B, T, W)
# Predictions for space-only baseline
if self.attention_type == 'space_only':
x = rearrange(x, '(b t) n m -> b t n m', b=B, t=T)
x = torch.mean(x, 1) # averaging predictions for every frame
x = self.norm(x)
if return_all_tokens:
return x
else:
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
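# Instantiation sketch for the backbone above (illustrative; the argument values are
# assumptions matching a ViT-B/16 layout with 8 frames, not a shipped configuration):
#
#   vit = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
#                           num_heads=12, mlp_ratio=4., qkv_bias=True,
#                           norm_layer=partial(nn.LayerNorm, eps=1e-6),
#                           num_frames=8, attention_type='divided_space_time')
#   video = torch.randn(2, 3, 8, 224, 224)                    # (B, C, T, H, W)
#   feats = vit.forward_features(video, return_all_tokens=True)
#   # feats: (B, 1 + 14*14*8, 768) -- a shared CLS token followed by all patch tokens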
def _conv_filter(state_dict, patch_size=16):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
if v.shape[-1] != patch_size:
patch_size = v.shape[-1]
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
out_dict[k] = v
return out_dict
class vit_base_patch16_224(nn.Module):
def __init__(self, cfg, **kwargs):
super(vit_base_patch16_224, self).__init__()
self.pretrained = True
patch_size = 16
self.model = VisionTransformer(img_size=cfg.DATA.TRAIN_CROP_SIZE, num_classes=cfg.MODEL.NUM_CLASSES, patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, norm_layer=partial(
nn.LayerNorm, eps=1e-6), drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, num_frames=cfg.DATA.NUM_FRAMES, attention_type=cfg.TIMESFORMER.ATTENTION_TYPE, **kwargs)
self.attention_type = cfg.TIMESFORMER.ATTENTION_TYPE
self.model.default_cfg = default_cfgs['vit_base_patch16_224']
self.num_patches = (cfg.DATA.TRAIN_CROP_SIZE // patch_size) * \
(cfg.DATA.TRAIN_CROP_SIZE // patch_size)
pretrained_model = cfg.TIMESFORMER.PRETRAINED_MODEL
if self.pretrained:
load_pretrained(self.model, num_classes=self.model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter,
img_size=cfg.DATA.TRAIN_CROP_SIZE, num_patches=self.num_patches, attention_type=self.attention_type, pretrained_model=pretrained_model)
def forward(self, x):
x = self.model(x)
return x
class TimeSformer(nn.Module):
def __init__(self, model_cfg, input_format='BGR', cross_attention_config=None, **kwargs):
super(TimeSformer, self).__init__()
self.config_file = str(model_cfg)
# model-specific configurations
self.img_size = model_cfg['img_size']
self.patch_size = model_cfg['patch_size']
self.num_frames = model_cfg['num_frm']
self.attn_drop_rate = model_cfg['attn_drop_rate']
self.drop_path_rate = model_cfg['drop_path_rate']
self.drop_rate = model_cfg['drop_rate']
self.use_pooling = model_cfg['use_maxpooling']
self.use_grad_ckpt = model_cfg['gradient_checkpointing']
self.attention_type = 'divided_space_time'
LOGGER.info(f'Initializing TimeSformer with img_size={self.img_size}, patch_size={self.patch_size}, num_frames={self.num_frames}')
# will be ignored when loading official pretrained ckpt
self.num_classes = 400
self.input_format = input_format
assert input_format == "RGB", "Official TimeSformer uses RGB input."
self.model = VisionTransformer(img_size=self.img_size,
num_classes=self.num_classes,
patch_size=self.patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
drop_rate=self.drop_rate,
attn_drop_rate=self.attn_drop_rate,
drop_path_rate=self.drop_path_rate,
num_frames=self.num_frames,
attention_type=self.attention_type,
cross_attention_config=cross_attention_config,
use_grad_checkpointing=self.use_grad_ckpt,
**kwargs
)
if self.use_pooling:
self.maxpool_kernel_size = model_cfg['maxpool_kernel_size']
self.maxpooling = torch.nn.MaxPool2d(kernel_size=self.maxpool_kernel_size)
self.model.default_cfg = default_cfgs['vit_base_patch' + str(self.patch_size)+'_224']
self.num_patches = (self.img_size // self.patch_size) * (self.img_size // self.patch_size)
def forward(self, x):
x = self.model(x)
return x
def forward_features(self, x, return_all_tokens=True, pooling='temporal'):
b, c, t, h, w = x.shape
x = self.model.forward_features(x, return_all_tokens=return_all_tokens)
## apply pooling
W = H = self.img_size // self.patch_size
T = self.num_frames
cls_tokens = x[:, 0, :].unsqueeze(1)
other_tokens = x[:, 1:, :]
x = rearrange(other_tokens, 'b (h w t) m -> b t (h w) m', h=H, w=W, t=T)
assert pooling in ['temporal', 'spatial', 'none'], 'Invalid pooling type {}'.format(pooling)
if pooling == 'temporal':
x = torch.mean(x, dim=1)
x = torch.cat((cls_tokens, x), dim=1)
elif pooling == 'spatial': # spatial pooling
# x = torch.max(x, dim=2)[0]
x = torch.mean(x, dim=2)
x = torch.cat((cls_tokens, x), dim=1)
elif pooling == 'none':
cls_tokens_repeat = cls_tokens.unsqueeze(1).repeat(1, T, 1, 1)
x = torch.cat((cls_tokens_repeat, x), dim=2)
else:
raise NotImplementedError('Unsupported pooling type {}'.format(pooling))
return x
def _get_pooled_features(self, x):
b, t, h, w, c = x.shape
# x = rearrange(x.transpose(2, 4).transpose(3, 4), 'b t h w c -> (b t c) h w')
x = rearrange(x, 'b t h w c -> (b t c) h w')
x = self.maxpooling(x)
x = rearrange(x, '(b t c) h w -> b (t h w) c', b=b, t=t)
return x
def load_state_dict(self, pretrained_ckpt_path):
LOGGER.info('Loading TimeSformer checkpoints from {}'.format(pretrained_ckpt_path))
if pretrained_ckpt_path == "vit_base_patch16_224":
load_ckpt_func = load_pretrained_imagenet
elif "CLIP_ViT" in pretrained_ckpt_path:
load_ckpt_func = load_pretrained_CLIP_ViT
else:
load_ckpt_func = load_pretrained_kinetics
load_ckpt_func(self.model,
num_classes=self.model.num_classes,
in_chans=3,
filter_fn=_conv_filter,
img_size=self.img_size,
num_frames=self.num_frames,
num_patches=self.num_patches,
attention_type=self.attention_type,
pretrained_model=pretrained_ckpt_path
)
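# Construction sketch (illustrative): TimeSformer expects a dict-like model_cfg with
# the keys read in __init__ above. The values here are assumptions, not a config
# shipped with the repository:
#
#   model_cfg = {
#       'img_size': 224, 'patch_size': 16, 'num_frm': 8,
#       'attn_drop_rate': 0., 'drop_rate': 0., 'drop_path_rate': 0.1,
#       'use_maxpooling': False, 'gradient_checkpointing': False,
#   }
#   video_enc = TimeSformer(model_cfg, input_format='RGB')
#   feats = video_enc.forward_features(torch.randn(2, 3, 8, 224, 224),
#                                      return_all_tokens=True, pooling='temporal')
#   # 'temporal' pooling averages the patch tokens over frames: feats is (2, 1 + 14*14, 768)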
================================================
FILE: src/modeling/timesformer/vit_utils.py
================================================
# Copyright 2020 Ross Wightman
# Various utility functions
import torch
import torch.nn as nn
from functools import partial
import math
import warnings
import torch.nn.functional as F
from src.modeling.timesformer.helpers import load_pretrained
from itertools import repeat
import collections.abc as container_abcs
DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def get_padding_value(padding, kernel_size, **kwargs):
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = get_padding(kernel_size, **kwargs)
else:
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = get_padding(kernel_size, **kwargs)
return padding, dynamic
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
return max((int(math.ceil(x / s)) - 1) * s + (k - 1) * d + 1 - x, 0)
# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
# Dynamically pad input x with 'SAME' padding for conv with specified args
#def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
def pad_same(x, k, s, d=(1, 1), value=0):
ih, iw = x.size()[-2:]
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
return x
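# Worked example for the 'SAME' padding helpers above (illustrative numbers): for an
# input of width x=224 with kernel k=7, stride s=2, dilation d=1, get_same_padding
# returns max((ceil(224/2) - 1)*2 + (7 - 1)*1 + 1 - 224, 0) = max(222 + 6 + 1 - 224, 0) = 5,
# which pad_same splits asymmetrically into (2 left, 3 right) via F.pad.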
def adaptive_pool_feat_mult(pool_type='avg'):
if pool_type == 'catavgmax':
return 2
else:
return 1
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect implementation used for EfficientNet-style networks; however,
the original name is misleading, as 'DropConnect' refers to a different form of dropout introduced in a
separate paper (see discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
The layer and argument are therefore named 'drop path' rather than mixing 'DropConnect' as a layer name
with 'survival rate' as the argument name.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
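# Behaviour sketch for stochastic depth (illustrative; shapes are assumptions):
#
#   dp = DropPath(drop_prob=0.1)
#   x = torch.randn(8, 197, 768)
#   dp.train(); y = dp(x)   # each sample is zeroed with prob 0.1, survivors scaled by 1/0.9
#   dp.eval();  y = dp(x)   # identity at inference time (drop_prob only applies in training)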
================================================
FILE: src/modeling/transformers.py
================================================
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """
import logging
import math
import os
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.activations import gelu, gelu_new, swish
from transformers.configuration_bert import BertConfig
from transformers.file_utils import (
add_start_docstrings, add_start_docstrings_to_callable)
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
logger = logging.getLogger(__name__)
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bert-base-uncased",
"bert-large-uncased",
"bert-base-cased",
"bert-large-cased",
"bert-base-multilingual-uncased",
"bert-base-multilingual-cased",
"bert-base-chinese",
"bert-base-german-cased",
"bert-large-uncased-whole-word-masking",
"bert-large-cased-whole-word-masking",
"bert-large-uncased-whole-word-masking-finetuned-squad",
"bert-large-cased-whole-word-masking-finetuned-squad",
"bert-base-cased-finetuned-mrpc",
"bert-base-german-dbmdz-cased",
"bert-base-german-dbmdz-uncased",
"cl-tohoku/bert-base-japanese",
"cl-tohoku/bert-base-japanese-whole-word-masking",
"cl-tohoku/bert-base-japanese-char",
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
"TurkuNLP/bert-base-finnish-cased-v1",
"TurkuNLP/bert-base-finnish-uncased-v1",
"wietsedv/bert-base-dutch-cased",
# See all BERT models at https://huggingface.co/models?filter=bert
]
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model.
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch,"
" requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ "
"for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split("/")
# adam_v and adam_m are variables used
# in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using the pretrained model
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer",
"AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info("Skipping {}".format("/".join(name)))
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu,
"swish": swish, "gelu_new": gelu_new, "mish": mish}
BertLayerNorm = LayerNorm
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(
config.vocab_size, config.hidden_size,
padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(
config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with
# TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids=None, token_type_ids=None,
position_ids=None, inputs_embeds=None):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = input_ids.device if input_ids is not None\
else inputs_embeds.device
if position_ids is None:
position_ids = torch.arange(
seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
if token_type_ids is None:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = (
inputs_embeds + position_embeddings + token_type_embeddings)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and\
not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(
config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (
self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
if encoder_hidden_states is not None:
mixed_key_layer = self.key(encoder_hidden_states)
mixed_value_layer = self.value(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key"
# to get the raw attention scores.
attention_scores = torch.matmul(
query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask
# (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (
self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs
) if self.output_attentions else (context_layer,)
return outputs
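# Note on the attention_mask convention used above (illustrative sketch): the mask is
# additive and is expected to be pre-broadcast to (batch, 1, 1, seq_len) by the calling
# model, with 0.0 for tokens to keep and a large negative value for padding. A typical
# construction (pad_token_id is an assumed input, not defined in this file):
#
#   keep = (input_ids != pad_token_id).float()             # (B, L)
#   extended_mask = (1.0 - keep)[:, None, None, :] * -10000.0
#   # extended_mask is what reaches `attention_scores = attention_scores + attention_mask`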
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.self.num_attention_heads,
self.self.attention_head_size)
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are
# before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(
heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
self_outputs = self.self(
hidden_states, attention_mask, head_mask,
encoder_hidden_states, encoder_attention_mask
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.attention = BertAttention(config)
self.is_decoder = config.is_decoder
if self.is_decoder:
self.crossattention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
self_attention_outputs = self.attention(
hidden_states, attention_mask, head_mask)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if self.is_decoder and encoder_hidden_states is not None:
cross_attention_outputs = self.crossattention(
attention_output, attention_mask, head_mask,
encoder_hidden_states, encoder_attention_mask
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output,) + outputs
return outputs
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(
config.num_hidden_layers)])
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states, attention_mask, head_mask[i],
encoder_hidden_states, encoder_attention_mask
)
hidden_states = layer_outputs[0]
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(
config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(
config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class BertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(
mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
================================================
FILE: src/modeling/xbert.py
================================================
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model. """
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import Tensor, device, dtype, nn
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MaskedLMOutput,
MultipleChoiceModelOutput,
NextSentencePredictorOutput,
QuestionAnsweringModelOutput,
SequenceClassifierOutput,
TokenClassifierOutput,
)
from transformers.modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.utils import logging
from transformers.models.bert.configuration_bert import BertConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bert-base-uncased",
"bert-large-uncased",
"bert-base-cased",
"bert-large-cased",
"bert-base-multilingual-uncased",
"bert-base-multilingual-cased",
"bert-base-chinese",
"bert-base-german-cased",
"bert-large-uncased-whole-word-masking",
"bert-large-cased-whole-word-masking",
"bert-large-uncased-whole-word-masking-finetuned-squad",
"bert-large-cased-whole-word-masking-finetuned-squad",
"bert-base-cased-finetuned-mrpc",
"bert-base-german-dbmdz-cased",
"bert-base-german-dbmdz-uncased",
"cl-tohoku/bert-base-japanese",
"cl-tohoku/bert-base-japanese-whole-word-masking",
"cl-tohoku/bert-base-japanese-char",
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
"TurkuNLP/bert-base-finnish-cased-v1",
"TurkuNLP/bert-base-finnish-uncased-v1",
"wietsedv/bert-base-dutch-cased",
# See all BERT models at https://huggingface.co/models?filter=bert
]
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
"""Load tf checkpoints in a pytorch model."""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
logger.error(
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
logger.info("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using the pretrained model
if any(
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
for n in name
):
logger.info("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
scope_names = re.split(r"_(\d+)", m_name)
else:
scope_names = [m_name]
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
pointer = getattr(pointer, "bias")
elif scope_names[0] == "output_weights":
pointer = getattr(pointer, "weight")
elif scope_names[0] == "squad":
pointer = getattr(pointer, "classifier")
else:
try:
pointer = getattr(pointer, scope_names[0])
except AttributeError:
logger.info("Skipping {}".format("/".join(name)))
continue
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, "weight")
elif m_name == "kernel":
array = np.transpose(array)
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
logger.info("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
):
if input_ids is not None:
input_shape = input_ids.size()
else:
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config, is_cross_attention):
super().__init__()
self.config = config
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
if is_cross_attention:
self.key = nn.Linear(config.encoder_width, self.all_head_size)
self.value = nn.Linear(config.encoder_width, self.all_head_size)
else:
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
self.max_position_embeddings = config.max_position_embeddings
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
self.save_attention = False
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
mixed_query_layer = self.query(hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
seq_length = hidden_states.size()[1]
position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
distance = position_ids_l - position_ids_r
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
if self.position_embedding_type == "relative_key":
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
if is_cross_attention and self.save_attention:
self.save_attention_map(attention_probs)
attention_probs.register_hook(self.save_attn_gradients)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs_dropped = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs_dropped = attention_probs_dropped * head_mask
context_layer = torch.matmul(attention_probs_dropped, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
outputs = outputs + (past_key_value,)
return outputs
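# Cross-attention sketch (illustrative; values are assumptions): with
# is_cross_attention=True, keys and values are projected from the visual encoder's
# hidden size (config.encoder_width) while queries come from the text stream:
#
#   cfg = BertConfig.from_pretrained('bert-base-uncased')
#   cfg.encoder_width = 768                                # video-encoder hidden size
#   xattn = BertSelfAttention(cfg, is_cross_attention=True)
#   text_h = torch.randn(2, 20, cfg.hidden_size)           # query side
#   video_h = torch.randn(2, 197, cfg.encoder_width)       # key/value side
#   out = xattn(text_h, encoder_hidden_states=video_h)[0]  # (2, 20, hidden_size)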
class BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config, is_cross_attention=False):
super().__init__()
self.self = BertSelfAttention(config, is_cross_attention)
self.output = BertSelfOutput(config)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
self_outputs = self.self(
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config, layer_num):
super().__init__()
self.config = config
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = BertAttention(config)
# self.has_cross_attention = (layer_num >= config.fusion_layer)
self.has_cross_attention = False
self.layer_num = layer_num
if self.has_cross_attention:
self.crossattention = BertAttention(config, is_cross_attention=True)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if self.has_cross_attention:
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
if type(encoder_hidden_states) == list:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)],
encoder_attention_mask[(self.layer_num-self.config.fusion_layer)%len(encoder_hidden_states)],
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1]
else:
cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
mode='multi_modal',
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
if mode=='text':
start_layer = 0
output_layer = self.config.fusion_layer
elif mode=='fusion':
start_layer = self.config.fusion_layer
output_layer = self.config.num_hidden_layers
elif mode=='multi_modal':
start_layer = 0
output_layer = self.config.num_hidden_layers
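        # Illustrative note (added comment, not in the original source): with ALPRO-style
        # config values such as fusion_layer=6 and num_hidden_layers=12 (assumed here only
        # for illustration), the three modes slice the layer stack as follows:
        #   mode='text'        -> layers [0, 6)   unimodal text encoding
        #   mode='fusion'      -> layers [6, 12)  layers above the fusion boundary
        #   mode='multi_modal' -> layers [0, 12)  the full stack in one pass
        # Any other mode string leaves start_layer/output_layer unset and fails in the
        # loop below.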
for i in range(start_layer, output_layer):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super().__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class BertPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@dataclass
class BertForPreTrainingOutput(ModelOutput):
"""
Output type of :class:`~transformers.BertForPreTraining`.
Args:
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
Total loss as the sum of the masked language modeling loss and the next sequence prediction
(classification) loss.
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
prediction_logits: torch.FloatTensor = None
seq_relationship_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
BERT_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)
    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.
Parameters:
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
weights.
"""
BERT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
1]``:
- 0 corresponds to a `sentence A` token,
- 1 corresponds to a `sentence B` token.
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@add_start_docstrings(
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
BERT_START_DOCSTRING,
)
class BertModel(BertPreTrainedModel):
"""
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    To behave as a decoder, the forward pass has to be called with :obj:`is_decoder` set to :obj:`True`; an
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.
"""
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.init_weights()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (:obj:`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (:obj:`Tuple[int]`):
The shape of the input to the model.
device: (:obj:`torch.device`):
The device of the input to the model.
Returns:
            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if is_decoder:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones(
(batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
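        # Worked example (illustrative comment, not in the original source), assuming a
        # 2D padding mask and is_decoder=False:
        #   attention_mask          = [[1, 1, 0]]               # shape (batch=1, seq=3)
        #   extended_attention_mask = [[[[0., 0., -10000.]]]]   # shape (1, 1, 1, 3)
        # Added to the raw attention scores before the softmax, the -10000. entries
        # effectively remove the padded position for every head and every query position.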
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
mode='multi_modal',
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = inputs_embeds.device
elif encoder_embeds is not None:
input_shape = encoder_embeds.size()[:-1]
batch_size, seq_length = input_shape
device = encoder_embeds.device
else:
raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
device, is_decoder)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if type(encoder_hidden_states) == list:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if type(encoder_attention_mask) == list:
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if encoder_embeds is None:
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
else:
embedding_output = encoder_embeds
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
mode=mode,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
@add_start_docstrings(
"""
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
""",
BERT_START_DOCSTRING,
)
class BertForPreTraining(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated.
Returns:
Example::
>>> from transformers import BertTokenizer, BertForPreTraining
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits
>>> seq_relationship_logits = outputs.seq_relationship_logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
if not return_dict:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=True,
reduction='mean',
mode='multi_modal',
soft_labels=None,
alpha=0,
return_logits=False,
):
r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
Returns:
Example::
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> config = BertConfig.from_pretrained("bert-base-cased")
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
mode=mode,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores[:, :-1, :].contiguous()
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction=reduction)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            if reduction == 'none':
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
if soft_labels is not None:
            loss_distill = -torch.sum(F.log_softmax(shifted_prediction_scores, dim=-1) * soft_labels, dim=-1)
loss_distill = (loss_distill * (labels!=-100)).sum(1)
lm_loss = (1-alpha)*lm_loss + alpha*loss_distill
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past,
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
"is_decoder": True,
}
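        # Illustrative comment (not in the original source): during incremental decoding,
        # once `past` holds the cached keys/values from earlier steps, only the most recent
        # token is fed again, i.e. input_ids of shape (batch, seq_len) is cut to
        # (batch, 1), while the full attention_mask is kept so cached positions stay visible.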
def _reorder_cache(self, past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class BertForMaskedLM(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
self.init_weights()
def get_output_embeddings(self):
return self.cls.predictions.decoder
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
is_decoder=False,
mode='multi_modal',
soft_labels=None,
alpha=0,
return_logits=False,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_embeds=encoder_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
is_decoder=is_decoder,
mode=mode,
)
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
if return_logits:
return prediction_scores
masked_lm_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if soft_labels is not None:
            loss_distill = -torch.sum(F.log_softmax(prediction_scores, dim=-1) * soft_labels, dim=-1)
loss_distill = loss_distill[labels!=-100].mean()
masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
# add a dummy token
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)
self.init_weights()
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
Returns:
Example::
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
"""
if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
seq_relationship_scores = self.cls(pooled_output)
next_sentence_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
BERT_START_DOCSTRING,
)
class BertForMultipleChoice(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.init_weights()
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=MultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
BERT_START_DOCSTRING,
)
class BertForTokenClassification(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)
active_labels = torch.where(
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
BERT_START_DOCSTRING,
)
class BertForQuestionAnswering(BertPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.init_weights()
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="bert-base-uncased",
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
================================================
FILE: src/optimization/adamw.py
================================================
"""
AdamW optimizer (weight decay fix)
copied from Hugging Face
"""
import math
import torch
from torch.optim import Optimizer
class AdamW(Optimizer):
""" Implements Adam algorithm with weight decay fix.
Parameters:
lr (float): learning rate. Default 1e-3.
betas (tuple of 2 floats): Adams beta parameters (b1, b2).
Default: (0.9, 0.999)
eps (float): Adams epsilon. Default: 1e-6
weight_decay (float): Weight decay. Default: 0.0
correct_bias (bool): can be set to False to avoid correcting bias
in Adam (e.g. like in Bert TF repository). Default True.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
weight_decay=0.0, correct_bias=True):
if lr < 0.0:
raise ValueError(
"Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - "
                             "should be in [0.0, 1.0)".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - "
                             "should be in [0.0, 1.0)".format(betas[1]))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {} - "
"should be >= 0.0".format(eps))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
correct_bias=correct_bias)
super(AdamW, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
'Adam does not support sparse '
'gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])
step_size = group['lr']
if group['correct_bias']: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state['step']
bias_correction2 = 1.0 - beta2 ** state['step']
step_size = (step_size * math.sqrt(bias_correction2)
/ bias_correction1)
p.data.addcdiv_(-step_size, exp_avg, denom)
# Just adding the square of the weights to the loss function is
# *not* the correct way of using L2 regularization/weight decay
# with Adam, since that will interact with the m and v
# parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't
# interact with the m/v parameters. This is equivalent to
# adding the square of the weights to the loss with plain
# (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group['weight_decay'] > 0.0:
p.data.add_(-group['lr'] * group['weight_decay'], p.data)
return loss
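# Minimal usage sketch (added for illustration; not part of the original file).
# It assumes only the AdamW class above plus a throwaway nn.Linear model, and shows
# the decoupled weight decay: the `weight_decay` term is applied to the parameters
# directly after the Adam update instead of being folded into the gradients.
if __name__ == '__main__':
    import torch.nn as nn
    model = nn.Linear(4, 2)
    optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                      eps=1e-6, weight_decay=1e-2, correct_bias=True)
    for _ in range(3):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss={loss.item():.4f}")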
================================================
FILE: src/optimization/sched.py
================================================
"""
optimizer learning rate scheduling helpers
"""
from math import ceil
from collections import Counter
def noam_schedule(step, warmup_step=4000):
if step <= warmup_step:
return step / warmup_step
return (warmup_step ** 0.5) * (step ** -0.5)
def warmup_linear(step, warmup_step, tot_step):
if step < warmup_step:
return step / warmup_step
return max(0, (tot_step-step)/(tot_step-warmup_step))
def multi_step_schedule(n_epoch, milestones, gamma=0.5):
milestones = list(sorted(milestones))
for i, m in enumerate(milestones):
if n_epoch < m:
return gamma**i
return gamma**(len(milestones)+1)
def get_lr_sched(global_step, decay, learning_rate,
num_train_steps, warmup_ratio=0.1,
decay_epochs=[], multi_step_epoch=-1):
warmup_steps = int(warmup_ratio*num_train_steps)
if decay == 'linear':
lr_this_step = learning_rate * warmup_linear(
global_step, warmup_steps, num_train_steps)
elif decay == 'invsqrt':
lr_this_step = learning_rate * noam_schedule(
global_step, warmup_steps)
elif decay == 'constant':
lr_this_step = learning_rate
elif decay == "multi_step":
assert multi_step_epoch >= 0
lr_this_step = learning_rate * multi_step_schedule(
multi_step_epoch, decay_epochs)
if lr_this_step <= 0:
        # safeguard against a possible miscalculation of train steps
lr_this_step = 1e-8
return lr_this_step
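# Illustrative sketch (added for clarity; not part of the original file): with linear
# decay, warmup_ratio=0.1 and num_train_steps=1000 (assumed values), the learning rate
# ramps up over the first 100 steps and then decays linearly, e.g. 0.5*lr at step 50,
# lr at step 100, and 0.5*lr again at step 550.
if __name__ == '__main__':
    base_lr = 1e-4
    for step in (0, 50, 100, 550, 1000):
        lr = get_lr_sched(step, decay='linear', learning_rate=base_lr,
                          num_train_steps=1000, warmup_ratio=0.1)
        print(f"step {step:4d}: lr = {lr:.2e}")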
================================================
FILE: src/optimization/utils.py
================================================
from torch.optim import Adam, Adamax, SGD
from src.optimization.adamw import AdamW
def setup_e2e_optimizer(model, opts):
if opts.optim == 'adam':
OptimCls = Adam
elif opts.optim == 'adamax':
OptimCls = Adamax
elif opts.optim == 'adamw':
OptimCls = AdamW
else:
raise ValueError('invalid optimizer')
optimizer = OptimCls(model.parameters(), lr=opts.learning_rate, betas=opts.betas)
return optimizer
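# Minimal usage sketch (illustrative; not part of the original file). The `opts` object
# only needs the three attributes read above; SimpleNamespace and the nn.Linear model
# are stand-ins used purely for demonstration.
if __name__ == '__main__':
    from types import SimpleNamespace
    import torch.nn as nn
    opts = SimpleNamespace(optim='adamw', learning_rate=1e-4, betas=(0.9, 0.999))
    optimizer = setup_e2e_optimizer(nn.Linear(4, 2), opts)
    print(type(optimizer).__name__)  # -> AdamW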
================================================
FILE: src/pretrain/run_pretrain_contrastive_only.py
================================================
import os
import torch
import time
import random
import pprint
import math
import json
from transformers import BertConfig, BertTokenizerFast
from src.datasets.dataset_pretrain_sparse import AlproPretrainSparseDataset, PretrainImageTextDataset, PretrainCollator
from src.datasets.dataloader import MetaLoader, PrefetchLoader
from src.datasets.data_utils import ImageNorm, mk_input_group
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from src.configs.config import shared_configs
from src.utils.misc import set_random_seed, NoOp, zero_none_grad
from src.utils.logger import LOGGER, TB_LOGGER, add_log_to_file, RunningMeter
from src.utils.basic_utils import load_jsonl, load_json, read_dataframe
from src.utils.load_save import (ModelSaver,
save_training_meta,
load_state_dict_with_pos_embed_resizing)
from src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer
from src.optimization.sched import get_lr_sched
from src.optimization.utils import setup_e2e_optimizer
from collections import defaultdict
from tqdm import tqdm
from os.path import join
from apex import amp
from torch.utils.data.distributed import DistributedSampler
import horovod.torch as hvd
from src.utils.distributed import all_gather_list
from src.modeling.alpro_models import Prompter
def mk_captions_pretrain_dataloader(dataset_name, anno_path, video_dir, txt_dir, cfg, tokenizer,
is_train=True, max_txt_len=80):
# make a list(dict), where each dict {vis_id: int, txt: str}
if dataset_name == "webvid2m":
datalist = read_dataframe(anno_path)
datalist = datalist[datalist['txt_len'] < max_txt_len]
LOGGER.info('Found {} entries for webvid2m'.format(len(datalist)))
elif dataset_name == "cc3m":
datalist = json.load(open(anno_path))
LOGGER.info('Found {} entries for cc3m'.format(len(datalist)))
else:
raise ValueError("Invalid dataset_name")
if dataset_name in ["webvid2m"]:
frm_sampling_strategy = cfg.frm_sampling_strategy
if not is_train and frm_sampling_strategy == "rand":
frm_sampling_strategy = "uniform"
dataset = AlproPretrainSparseDataset(
datalist=datalist,
tokenizer=tokenizer,
img_lmdb_dir=video_dir,
img_db_type='rawvideo',
txt_dir=txt_dir,
crop_size=cfg.crop_img_size,
resize_size=cfg.resize_size,
max_txt_len=cfg.max_txt_len,
use_itm=cfg.use_itm,
fps=cfg.fps,
num_frm=cfg.num_frm,
frm_sampling_strategy=frm_sampling_strategy,
is_train=is_train
# vis_format=vis_format
)
elif dataset_name in ["cc3m"]:
dataset = PretrainImageTextDataset(datalist=datalist,
tokenizer=tokenizer,
crop_size=cfg.crop_img_size,
resize_size=cfg.resize_size,
max_txt_len=cfg.max_txt_len,
num_frm=cfg.num_frm
)
LOGGER.info(f"[{dataset_name}] is_train {is_train} "
f"dataset size {len(dataset)}, ")
batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
sampler = DistributedSampler(
dataset, num_replicas=hvd.size(), rank=hvd.rank(),
shuffle=is_train)
data_collator = PretrainCollator(tokenizer=tokenizer,
mlm=cfg.use_mlm,
mlm_probability=0.15,
max_length=cfg.max_txt_len,
is_train=is_train)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=False,
sampler=sampler,
num_workers=cfg.n_workers,
pin_memory=cfg.pin_mem,
collate_fn=data_collator.collate_batch)
return dataloader
def setup_dataloaders(cfg, tokenizer):
LOGGER.info("Init. train_loader and val_loader...")
train_loaders = {}
for db in cfg.train_datasets:
train_loaders[db.name] = mk_captions_pretrain_dataloader(
dataset_name=db.name,
anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,
cfg=cfg, tokenizer=tokenizer, is_train=True
)
val_loaders = {}
for db in cfg.val_datasets:
val_loaders[db.name] = mk_captions_pretrain_dataloader(
dataset_name=db.name,
anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,
cfg=cfg, tokenizer=tokenizer, is_train=False
)
return train_loaders, val_loaders
def setup_model(cfg, device=None):
LOGGER.info("Setup model...")
# has to be a BertConfig instance
model_cfg = load_json(cfg.model_config)
model_cfg = BertConfig(**model_cfg)
# add model-specific config
add_attr_list = [
"max_n_example_per_group",
"num_entities"
]
for k in add_attr_list:
setattr(model_cfg, k, cfg[k])
LOGGER.info(f"model_cfg {pprint.pformat(model_cfg.to_dict())}")
LOGGER.info("setup e2e model")
if cfg.model_type == 'pretrain':
# initialize cnn config
video_enc_cfg = load_json(cfg.visual_model_cfg)
video_enc_cfg['num_frm'] = cfg.num_frm
video_enc_cfg['img_size'] = cfg.crop_img_size
model = Prompter(
model_cfg,
input_format=cfg.img_input_format,
video_enc_cfg=video_enc_cfg
)
if cfg.e2e_weights_path:
LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2
# NOTE strict is False if loaded from an ALBEF ckpt
load_state_dict_with_pos_embed_resizing(model,
cfg.e2e_weights_path,
num_patches=num_patches,
num_frames=cfg.num_frm,
strict=not cfg.albef_init
)
else:
LOGGER.info(f"Loading visual weights from {cfg.visual_weights_path}")
LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
model.load_separate_ckpt(
visual_weights_path=cfg.visual_weights_path,
bert_weights_path=cfg.bert_weights_path
)
else:
raise NotImplementedError(f"cfg.model_type not found {cfg.model_type}.")
# if cfg.freeze_cnn:
# model.freeze_cnn_backbone()
LOGGER.info("Moving model to device")
model.to(device)
LOGGER.info("Completed moving model to device.")
LOGGER.info("Setup model done!")
return model
def forward_step(cfg, model, batch):
"""shared for training and validation"""
# used to make visual feature copies
if not cfg.use_itm:
batch["itm_labels"] = None
outputs = model(batch) # dict
return outputs
@torch.no_grad()
def validate(model, val_loader, cfg):
model.eval()
n_itc_ex = 0
n_t2i_corrects = 0
n_i2t_corrects = 0
itc_loss = 0
st = time.time()
val_log = {'valid/itc_loss': 0,
'valid/i2t_acc': 0,
'valid/t2i_acc': 0
}
debug_step = 5
val_loaders = val_loader if isinstance(val_loader, dict) else {
"unnamed_val_loader": val_loader}
total_val_iters = 0
LOGGER.info(f"In total {len(val_loaders)} val loaders")
for loader_name, val_loader in val_loaders.items():
LOGGER.info(f"Loop val_loader {loader_name}.")
total_val_iters += len(val_loader)
for val_step, batch in enumerate(val_loader):
# use iter to reset MetaLoader
# forward pass
outputs = forward_step(cfg, model, batch)
assert not cfg.use_itm and not cfg.use_mlm
if cfg.use_itc:
itc_loss += outputs["itc_loss"].sum().item()
if cfg.debug and val_step >= debug_step:
break
# Gather across all processes
all_gather_itc_loss = all_gather_list(itc_loss)
itc_loss = sum(all_gather_itc_loss)
# FIXME check whether this should take the mean
assert cfg.use_itc, 'cfg.use_itc is False for contrastive-only pretraining.'
val_log.update({
'valid/itc_loss': float(itc_loss),
})
n_itc_ex += len(outputs["itc_labels"])
n_t2i_corrects += (
outputs["t2i_scores"].max(
dim=-1)[1] == outputs["itc_labels"]).sum().item()
n_i2t_corrects += (
outputs["i2t_scores"].max(
dim=-1)[1] == outputs["itc_labels"]).sum().item()
n_i2t_corrects = sum(all_gather_list(n_i2t_corrects))
n_t2i_corrects = sum(all_gather_list(n_t2i_corrects))
n_itc_ex = sum(all_gather_list(n_itc_ex))
if n_itc_ex != 0:
val_log.update({
'valid/i2t_acc': float(n_i2t_corrects / n_itc_ex),
'valid/t2i_acc': float(n_t2i_corrects / n_itc_ex)
})
TB_LOGGER.log_scalar_dict(val_log)
LOGGER.info(f"validation finished in {int(time.time() - st)} seconds, ")
LOGGER.info("[itc_loss]: {} ".format(itc_loss))
LOGGER.info("In total, {} validation iters.".format(total_val_iters))
model.train()
return val_log
def start_training():
cfg = shared_configs.get_sparse_pretraining_args()
set_random_seed(cfg.seed)
n_gpu = hvd.size()
# device = torch.device("cuda", hvd.local_rank())
# torch.cuda.set_device(hvd.local_rank())
# This resolves the issue where GPU 0 always has more processes running and uses more GPU RAM.
# c.f. https://github.com/horovod/horovod/issues/2625#issuecomment-868134876
os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())
device = torch.device("cuda", 0)
torch.cuda.set_device(0)
if hvd.rank() != 0:
LOGGER.disabled = True
LOGGER.info(f"device: {device} n_gpu: {n_gpu}, "
f"rank: {hvd.rank()}, 16-bits training: {cfg.fp16}")
model = setup_model(cfg, device=device)
model.train()
optimizer = setup_e2e_optimizer(model, cfg)
# Horovod: (optional) compression algorithm.
compression = hvd.Compression.none
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression)
# Horovod: broadcast parameters & optimizer state.
compression = hvd.Compression.none
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
model, optimizer = amp.initialize(
model, optimizer, enabled=cfg.fp16, opt_level='O2',
keep_batchnorm_fp32=True)
# prepare data
tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
train_loaders, val_loaders = setup_dataloaders(cfg, tokenizer)
train_loader = MetaLoader(train_loaders,
accum_steps=cfg.gradient_accumulation_steps,
distributed=n_gpu > 1)
img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
train_loader = PrefetchLoader(train_loader, img_norm)
val_loaders = {k: PrefetchLoader(v, img_norm)
for k, v in val_loaders.items()}
# compute the number of steps and update cfg
total_train_batch_size = int(
n_gpu * cfg.train_batch_size *
cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)
total_n_epochs = cfg.num_train_epochs
cfg.num_train_steps = int(math.ceil(
1. * train_loader.n_batches_in_epoch * total_n_epochs /
(n_gpu * cfg.gradient_accumulation_steps)))
cfg.valid_steps = int(math.ceil(
1. * cfg.num_train_steps / cfg.num_valid /
cfg.min_valid_steps)) * cfg.min_valid_steps
actual_num_valid = int(math.floor(
1. * cfg.num_train_steps / cfg.valid_steps)) + 1
# restore
restorer = TrainingRestorer(cfg, model, optimizer)
global_step = restorer.global_step
TB_LOGGER.global_step = global_step
if hvd.rank() == 0:
LOGGER.info("Saving training meta...")
save_training_meta(cfg)
LOGGER.info("Saving training done...")
TB_LOGGER.create(join(cfg.output_dir, 'log'))
pbar = tqdm(total=cfg.num_train_steps)
model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
restorer = NoOp()
if global_step > 0:
pbar.update(global_step)
LOGGER.info(cfg)
LOGGER.info("Starting training...")
LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
LOGGER.info(f" Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
LOGGER.info(f" max_n_example_per_group = {cfg.max_n_example_per_group}")
LOGGER.info(f" Accumulate steps = {cfg.gradient_accumulation_steps}")
LOGGER.info(f" Total batch size = #GPUs * Single-GPU batch size * "
f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}")
LOGGER.info(f" Total #batches - single epoch = {train_loader.n_batches_in_epoch}.")
LOGGER.info(f" Total #steps = {cfg.num_train_steps}")
LOGGER.info(f" Total #epochs = {total_n_epochs}.")
LOGGER.info(f" Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times")
# quick hack for amp delay_unscale bug
with optimizer.skip_synchronize():
optimizer.zero_grad()
if global_step == 0:
optimizer.step()
debug_step = 5
tasks = []
for name, flag in zip(["itc"], [cfg.use_itc]):
if flag:
tasks.append(name)
task2loss = {t: RunningMeter(f'train_loss/{t}')
for t in tasks}
task2loss["loss"] = RunningMeter('train_loss/loss')
train_log = {'train/i2t_acc': 0,
'train/t2i_acc': 0}
for step, (task, batch) in enumerate(train_loader):
# forward pass
outputs = forward_step(cfg, model, batch)
# mlm_loss, itm_loss, itc_loss, mpm_loss = 0, 0, 0, 0
itc_loss = 0
assert not cfg.use_mlm and not cfg.use_itm
if cfg.use_itc:
n_itc_ex = len(outputs["itc_labels"])
n_t2i_corrects = (
outputs["t2i_scores"].max(
dim=-1)[1] == outputs["itc_labels"]).sum().item()
n_i2t_corrects = (
outputs["i2t_scores"].max(
dim=-1)[1] == outputs["itc_labels"]).sum().item()
train_log.update({
'train/t2i_acc': float(n_t2i_corrects / n_itc_ex),
'train/i2t_acc': float(n_i2t_corrects / n_itc_ex),
# 'train/mpm_acc': mpm_acc
})
itc_loss = outputs["itc_loss"]
task2loss["itc"](itc_loss.item())
loss = itc_loss
task2loss["loss"](loss.item())
delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
with amp.scale_loss(
loss, optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
zero_none_grad(model)
optimizer.synchronize()
# optimizer
if (step + 1) % cfg.gradient_accumulation_steps == 0:
global_step += 1
if (step + 1) % cfg.log_interval == 0:
TB_LOGGER.log_scalar_dict({l.name: l.val
for l in task2loss.values()
if l.val is not None})
n_epoch = int(1. * n_gpu * cfg.gradient_accumulation_steps *
global_step / train_loader.n_batches_in_epoch)
# learning rate scheduling for the whole model
lr_this_step = get_lr_sched(
global_step, cfg.decay, cfg.learning_rate,
cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,
decay_epochs=cfg.step_decay_epochs,
multi_step_epoch=n_epoch)
# Hardcoded param group length
# assert len(optimizer.param_groups) == 8
for pg_n, param_group in enumerate(
optimizer.param_groups):
param_group['lr'] = lr_this_step
if (step + 1) % cfg.log_interval == 0:
TB_LOGGER.add_scalar(
"train/lr", lr_this_step, global_step)
# update model params
if cfg.grad_norm != -1:
# import pdb; pdb.set_trace()
grad_norm = clip_grad_norm_(
amp.master_params(optimizer), cfg.grad_norm)
if (step + 1) % cfg.log_interval == 0:
TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step)
TB_LOGGER.step()
# Check if there is None grad
none_grads = [
p[0] for p in model.named_parameters()
if p[1].requires_grad and p[1].grad is None]
assert len(none_grads) == 0, f"{none_grads}"
with optimizer.skip_synchronize():
optimizer.step()
optimizer.zero_grad()
restorer.step()
pbar.update(1)
# validate and checkpoint
if global_step % cfg.valid_steps == 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(model, val_loaders, cfg)
model_saver.save(step=global_step, model=model)
if global_step >= cfg.num_train_steps:
break
if cfg.debug and global_step >= debug_step:
break
if global_step % cfg.valid_steps != 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(model, val_loaders, cfg)
model_saver.save(step=global_step, model=model)
if __name__ == '__main__':
# Initialize Horovod
hvd.init()
start_training()
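# Worked example (editor addition, not part of the original file) of the step
# bookkeeping in start_training(), using hypothetical values:
#   n_gpu = 8, gradient_accumulation_steps = 2,
#   n_batches_in_epoch = 10000, num_train_epochs = 10,
#   num_valid = 20, min_valid_steps = 100
#   num_train_steps  = ceil(10000 * 10 / (8 * 2))        = 6250
#   valid_steps      = ceil(6250 / 20 / 100) * 100       = 400
#   actual_num_valid = floor(6250 / 400) + 1             = 16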
================================================
FILE: src/pretrain/run_pretrain_sparse.py
================================================
import os
import torch
import time
import random
import pprint
import math
import json
from transformers import BertConfig, BertTokenizerFast
from src.datasets.dataset_pretrain_sparse import AlproPretrainSparseDataset, PretrainImageTextDataset, PretrainCollator
from src.datasets.dataloader import MetaLoader, PrefetchLoader
from src.datasets.data_utils import ImageNorm, mk_input_group
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from src.configs.config import shared_configs
from src.utils.misc import set_random_seed, NoOp, zero_none_grad
from src.utils.logger import LOGGER, TB_LOGGER, add_log_to_file, RunningMeter
from src.utils.basic_utils import load_jsonl, load_json, read_dataframe
from src.utils.load_save import (ModelSaver,
save_training_meta,
load_state_dict_with_pos_embed_resizing)
from src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer
from src.optimization.sched import get_lr_sched
from src.optimization.utils import setup_e2e_optimizer
from collections import defaultdict
from tqdm import tqdm
from os.path import join
from apex import amp
from torch.utils.data.distributed import DistributedSampler
import horovod.torch as hvd
from src.utils.distributed import all_gather_list
from src.modeling.alpro_models import AlproForPretrain
def mk_captions_pretrain_dataloader(dataset_name, anno_path, video_dir, txt_dir, cfg, tokenizer,
is_train=True, max_txt_len=80):
# make a list(dict), where each dict {vis_id: int, txt: str}
if dataset_name == "webvid2m":
datalist = read_dataframe(anno_path)
datalist = datalist[datalist['txt_len'] < max_txt_len]
LOGGER.info('Found {} entries for webvid2m'.format(len(datalist)))
elif dataset_name == "cc3m":
datalist = json.load(open(anno_path))
LOGGER.info('Found {} entries for cc3m'.format(len(datalist)))
else:
raise ValueError("Invalid dataset_name")
if dataset_name in ["webvid2m"]:
frm_sampling_strategy = cfg.frm_sampling_strategy
if not is_train and frm_sampling_strategy == "rand":
frm_sampling_strategy = "uniform"
dataset = AlproPretrainSparseDataset(
datalist=datalist,
tokenizer=tokenizer,
img_lmdb_dir=video_dir,
img_db_type='rawvideo',
txt_dir=txt_dir,
crop_size=cfg.crop_img_size,
resize_size=cfg.resize_size,
max_txt_len=cfg.max_txt_len,
use_itm=cfg.use_itm,
fps=cfg.fps,
num_frm=cfg.num_frm,
frm_sampling_strategy=frm_sampling_strategy,
is_train=is_train
# vis_format=vis_format
)
elif dataset_name in ["cc3m"]:
dataset = PretrainImageTextDataset(datalist=datalist,
tokenizer=tokenizer,
crop_size=cfg.crop_img_size,
resize_size=cfg.resize_size,
max_txt_len=cfg.max_txt_len,
num_frm=cfg.num_frm
)
LOGGER.info(f"[{dataset_name}] is_train {is_train} "
f"dataset size {len(dataset)}, ")
batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
sampler = DistributedSampler(
dataset, num_replicas=hvd.size(), rank=hvd.rank(),
shuffle=is_train)
data_collator = PretrainCollator(tokenizer=tokenizer,
mlm=cfg.use_mlm,
mlm_probability=0.15,
max_length=cfg.max_txt_len,
mpm=cfg.use_mpm,
is_train=is_train)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=False,
sampler=sampler,
num_workers=cfg.n_workers,
pin_memory=cfg.pin_mem,
collate_fn=data_collator.collate_batch)
return dataloader
def setup_dataloaders(cfg, tokenizer):
LOGGER.info("Init. train_loader and val_loader...")
train_loaders = {}
for db in cfg.train_datasets:
train_loaders[db.name] = mk_captions_pretrain_dataloader(
dataset_name=db.name,
anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,
cfg=cfg, tokenizer=tokenizer, is_train=True
)
val_loaders = {}
for db in cfg.val_datasets:
val_loaders[db.name] = mk_captions_pretrain_dataloader(
dataset_name=db.name,
anno_path=db.ann, video_dir=db.img, txt_dir=db.txt,
cfg=cfg, tokenizer=tokenizer, is_train=False
)
return train_loaders, val_loaders
def setup_model(cfg, device=None):
LOGGER.info("Setup model...")
# has to be a BertConfig instance
model_cfg = load_json(cfg.model_config)
model_cfg = BertConfig(**model_cfg)
# add model-specific config
add_attr_list = [
"max_n_example_per_group",
"num_entities"
]
for k in add_attr_list:
setattr(model_cfg, k, cfg[k])
LOGGER.info(f"model_cfg {pprint.pformat(model_cfg.to_dict())}")
LOGGER.info("setup e2e model")
if cfg.model_type == 'pretrain':
# initialize cnn config
video_enc_cfg = load_json(cfg.visual_model_cfg)
video_enc_cfg['num_frm'] = cfg.num_frm
video_enc_cfg['img_size'] = cfg.crop_img_size
model = AlproForPretrain(
model_cfg,
input_format=cfg.img_input_format,
video_enc_cfg=video_enc_cfg
)
if cfg.e2e_weights_path:
LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2
# NOTE strict is False if loaded from an ALBEF ckpt
load_state_dict_with_pos_embed_resizing(model,
cfg.e2e_weights_path,
num_patches=num_patches,
num_frames=cfg.num_frm,
strict=True
)
else:
LOGGER.info(f"Loading visual weights from {cfg.visual_weights_path}")
model.load_separate_ckpt(
visual_weights_path=cfg.visual_weights_path,
prompter_weights_path=cfg.teacher_weights_path
)
else:
raise NotImplementedError(f"cfg.model_type not found {cfg.model_type}.")
# if cfg.freeze_cnn:
# model.freeze_cnn_backbone()
LOGGER.info("Moving model to device")
model.to(device)
LOGGER.info("Completed moving model to device.")
LOGGER.info("Setup model done!")
return model
def forward_step(cfg, model, batch):
"""shared for training and validation"""
# used to make visual feature copies
if not cfg.use_itm:
batch["itm_labels"] = None
outputs = model(batch) # dict
return outputs
@torch.no_grad()
def validate(model, val_loader, cfg):
model.eval()
mlm_loss = 0
n_mlm_tokens = 0
n_mlm_corrects = 0
itm_loss = 0
n_itm_ex = 0
n_itm_corrects = 0
itc_loss = 0
mpm_loss = 0
n_mpm_ex = 0
n_mpm_corrects = 0
st = time.time()
val_log = {'valid/mlm_loss': 0, 'valid/mlm_acc': 0,
'valid/itm_loss': 0, 'valid/itm_acc': 0,
'valid/mpm_loss': 0, 'valid/mpm_acc': 0,
'valid/itc_loss': 0}
debug_step = 5
val_loaders = val_loader if isinstance(val_loader, dict) else {
"unnamed_val_loader": val_loader}
total_val_iters = 0
LOGGER.info(f"In total {len(val_loaders)} val loaders")
for loader_name, val_loader in val_loaders.items():
LOGGER.info(f"Loop val_loader {loader_name}.")
total_val_iters += len(val_loader)
for val_step, batch in enumerate(val_loader):
# use iter to reset MetaLoader
# forward pass
outputs = forward_step(cfg, model, batch)
# mlm
mlm_labels = outputs["mlm_labels"]
if cfg.use_mlm:
mlm_loss += outputs["mlm_loss"].sum().item()
mlm_mask = mlm_labels != -100 # (B, Lt) -100 is the ignored label for cross entropy
n_mlm_tokens += mlm_mask.sum().item()
if n_mlm_tokens > 0:
n_mlm_corrects += (
outputs["mlm_scores"][mlm_mask].max(
dim=-1)[1] == mlm_labels[mlm_mask]).sum().item()
else:
n_mlm_corrects = 0
# itm
if cfg.use_itm:
itm_loss += outputs["itm_loss"].sum().item()
n_itm_ex += len(outputs["itm_labels"])
n_itm_corrects += (
outputs["itm_scores"].max(
dim=-1)[1] == outputs["itm_labels"]).sum().item()
if cfg.use_itc:
itc_loss += outputs["itc_loss"].sum().item()
if cfg.use_mpm:
mpm_labels = outputs["mpm_labels"]
if mpm_labels is not None:
n_mpm_ex += len(mpm_labels)
n_mpm_corrects += (
outputs["mpm_logits"].max(
dim=-1)[1] == outputs["mpm_labels"].max(dim=-1)[1]).sum().item()
mpm_loss += outputs["mpm_loss"].sum().item()
if cfg.debug and val_step >= debug_step:
break
# Gather across all processes
# mlm_loss = sum(all_gather_list(mlm_loss))
all_gather_mlm_loss = all_gather_list(mlm_loss)
mlm_loss = sum(all_gather_mlm_loss)
n_mlm_corrects = sum(all_gather_list(n_mlm_corrects))
n_mlm_tokens = sum(all_gather_list(n_mlm_tokens))
all_gather_itm_loss = all_gather_list(itm_loss)
itm_loss = sum(all_gather_itm_loss)
n_itm_corrects = sum(all_gather_list(n_itm_corrects))
n_itm_ex = sum(all_gather_list(n_itm_ex))
all_gather_itc_loss = all_gather_list(itc_loss)
itc_loss = sum(all_gather_itc_loss)
all_gather_mpm_loss = all_gather_list(mpm_loss)
mpm_loss = sum(all_gather_mpm_loss)
n_mpm_corrects = sum(all_gather_list(n_mpm_corrects))
n_mpm_ex = sum(all_gather_list(n_mpm_ex))
if n_mlm_tokens != 0:
val_log.update({
'valid/mlm_loss': float(mlm_loss),
'valid/mlm_acc': float(n_mlm_corrects / n_mlm_tokens)
})
# FIXME check whether this should take the mean
if n_itm_ex != 0:
val_log.update({
'valid/itm_loss': float(itm_loss),
'valid/itm_acc': float(n_itm_corrects / n_itm_ex)
})
# FIXME check whether this should take the mean
if cfg.use_itc:
val_log.update({
'valid/itc_loss': float(itc_loss),
})
if n_mpm_ex != 0:
val_log.update({
'valid/mpm_loss': float(mpm_loss),
'valid/mpm_acc': float(n_mpm_corrects / n_mpm_ex)
})
TB_LOGGER.log_scalar_dict(val_log)
LOGGER.info(f"validation finished in {int(time.time() - st)} seconds, "
f"[mlm_acc (per token)]: {val_log['valid/mlm_acc'] * 100:.2f} "
f"[mpm_acc (per token)]: {val_log['valid/mpm_acc'] * 100:.2f} "
f"[itm_acc (per example)]: {val_log['valid/itm_acc'] * 100:.2f} ")
LOGGER.info("[mlm_loss]: {} ".format(mlm_loss))
LOGGER.info("[itm_loss]: {} ".format(itm_loss))
LOGGER.info("[itc_loss]: {} ".format(itc_loss))
LOGGER.info("In total, {} validation iters.".format(total_val_iters))
model.train()
return val_log
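# Illustrative sketch (editor addition, not part of the original file): the
# per-token MLM accuracy computed in validate() above, on hypothetical tensors;
# -100 marks positions that are ignored by the cross-entropy loss.
def _example_mlm_token_accuracy():
    mlm_labels = torch.tensor([[-100, 2048, -100, 311]])  # (B, Lt)
    mlm_scores = torch.randn(1, 4, 30522)                 # (B, Lt, vocab_size)
    mlm_mask = mlm_labels != -100
    n_tokens = mlm_mask.sum().item()
    n_correct = (mlm_scores[mlm_mask].max(
        dim=-1)[1] == mlm_labels[mlm_mask]).sum().item()
    return n_correct / n_tokens if n_tokens > 0 else 0.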
def get_video_prompt_templates():
prompts = [
'A footage of a {}.',
'A footage of the {}.',
'A footage of one {}.',
'A video of a {}.',
'A video of the {}.',
'A video of one {}.',
'A portrait of a {}.',
'A portrait of the {}.',
'A portrait of one {}.',
'A video footage of a {}.',
'A video footage of the {}.',
'A video footage of one {}.'
]
return prompts
def get_image_prompt_templates():
prompts = [
# basics
'A photo of a {}.',
'A photo of the {}.',
'A photo of one {}.',
'A picture of a {}.',
'A picture of the {}.',
'A picture of one {}.',
# good photo/picture
'A good photo of the {}.',
'A good photo of a {}.',
'A good photo of one {}.',
'A good picture of the {}.',
'A good picture of a {}.',
'A good picture of one {}.'
]
return prompts
def setup_text_prompts(cfg, tokenizer):
entity_filepath = cfg.entity_file_path
entity_num = cfg.num_entities
content = open(entity_filepath).read().split('\n')[:entity_num]
entities = [c.split(' ')[0] for c in content]
video_prompt_templates = get_video_prompt_templates()
image_prompt_templates = get_image_prompt_templates()
video_prompts = []
for template in video_prompt_templates:
video_prompts.extend([template.format(e) for e in entities])
image_prompts = []
for template in image_prompt_templates:
image_prompts.extend([template.format(e) for e in entities])
batch_enc_video_prompts = tokenizer.batch_encode_plus(
video_prompts,
max_length=15,
padding="max_length",
return_tensors="pt"
)
batch_enc_image_prompts = tokenizer.batch_encode_plus(
image_prompts,
max_length=15,
padding="max_length",
return_tensors="pt"
)
return dict(video_prompts=video_prompts,
image_prompts=image_prompts,
batch_enc_video_prompts=batch_enc_video_prompts,
batch_enc_image_prompts=batch_enc_image_prompts
)
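# Illustrative sketch (editor addition, not part of the original file): the
# prompt expansion above is a plain cross product of templates and entities,
# so 12 video templates and N entities yield 12 * N prompts before
# tokenization. The entity names below are hypothetical.
def _example_prompt_expansion():
    entities = ['dog', 'guitar']
    templates = get_video_prompt_templates()
    prompts = [t.format(e) for t in templates for e in entities]
    assert len(prompts) == len(templates) * len(entities)  # 12 * 2 == 24
    return prompts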
def start_training():
os.environ["TOKENIZERS_PARALLELISM"] = "false"
cfg = shared_configs.get_sparse_pretraining_args()
set_random_seed(cfg.seed)
n_gpu = hvd.size()
# device = torch.device("cuda", hvd.local_rank())
# torch.cuda.set_device(hvd.local_rank())
# This resolves the issue where GPU 0 always has more processes running and uses more GPU RAM.
# c.f. https://github.com/horovod/horovod/issues/2625#issuecomment-868134876
os.environ['CUDA_VISIBLE_DEVICES'] = str(hvd.local_rank())
device = torch.device("cuda", 0)
torch.cuda.set_device(0)
if hvd.rank() != 0:
LOGGER.disabled = True
LOGGER.info(f"device: {device} n_gpu: {n_gpu}, "
f"rank: {hvd.rank()}, 16-bits training: {cfg.fp16}")
model = setup_model(cfg, device=device)
model.train()
optimizer = setup_e2e_optimizer(model, cfg)
# Horovod: (optional) compression algorithm.
compression = hvd.Compression.none
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression)
# Horovod: broadcast parameters & optimizer state.
compression = hvd.Compression.none
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
model, optimizer = amp.initialize(
model, optimizer, enabled=cfg.fp16, opt_level='O1')
# keep_batchnorm_fp32=True)
# prepare data
tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
train_loaders, val_loaders = setup_dataloaders(cfg, tokenizer)
train_loader = MetaLoader(train_loaders,
accum_steps=cfg.gradient_accumulation_steps,
distributed=n_gpu > 1)
img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
train_loader = PrefetchLoader(train_loader, img_norm)
val_loaders = {k: PrefetchLoader(v, img_norm)
for k, v in val_loaders.items()}
# compute the number of steps and update cfg
total_train_batch_size = int(
n_gpu * cfg.train_batch_size *
cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)
total_n_epochs = cfg.num_train_epochs
cfg.num_train_steps = int(math.ceil(
1. * train_loader.n_batches_in_epoch * total_n_epochs /
(n_gpu * cfg.gradient_accumulation_steps)))
cfg.valid_steps = int(math.ceil(
1. * cfg.num_train_steps / cfg.num_valid /
cfg.min_valid_steps)) * cfg.min_valid_steps
actual_num_valid = int(math.floor(
1. * cfg.num_train_steps / cfg.valid_steps)) + 1
save_steps = int(cfg.save_steps_ratio * cfg.num_train_steps)
# restore
restorer = TrainingRestorer(cfg, model, optimizer)
global_step = restorer.global_step
TB_LOGGER.global_step = global_step
if hvd.rank() == 0:
LOGGER.info("Saving training meta...")
save_training_meta(cfg)
LOGGER.info("Saving training done...")
TB_LOGGER.create(join(cfg.output_dir, 'log'))
pbar = tqdm(total=cfg.num_train_steps)
model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
restorer = NoOp()
if global_step > 0:
pbar.update(global_step)
LOGGER.info(cfg)
LOGGER.info("Starting training...")
LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
LOGGER.info(f" Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
LOGGER.info(f" max_n_example_per_group = {cfg.max_n_example_per_group}")
LOGGER.info(f" Accumulate steps = {cfg.gradient_accumulation_steps}")
LOGGER.info(f" Total batch size = #GPUs * Single-GPU batch size * "
f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}")
LOGGER.info(f" Total #batches - single epoch = {train_loader.n_batches_in_epoch}.")
LOGGER.info(f" Total #steps = {cfg.num_train_steps}")
LOGGER.info(f" Total #epochs = {total_n_epochs}.")
LOGGER.info(f" Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times")
# quick hack for amp delay_unscale bug
with optimizer.skip_synchronize():
optimizer.zero_grad()
if global_step == 0:
optimizer.step()
debug_step = 20
tasks = []
for name, flag in zip(["mlm", "itm", "itc", "mpm"], [cfg.use_mlm, cfg.use_itm, cfg.use_itc, cfg.use_mpm]):
if flag:
tasks.append(name)
task2loss = {t: RunningMeter(f'train_loss/{t}')
for t in tasks}
task2loss["loss"] = RunningMeter('train_loss/loss')
train_log = {'train/mlm_acc': 0,
'train/itm_acc': 0,
'train/mpm_acc': 0,
}
# create tokenized prompts
if not cfg.e2e_weights_path and cfg.use_mpm:
text_prompts = setup_text_prompts(cfg, tokenizer)
model.build_text_prompts(text_prompts)
for step, (task, batch) in enumerate(train_loader):
# forward pass
outputs = forward_step(cfg, model, batch)
mlm_loss, itm_loss, itc_loss, mpm_loss = 0, 0, 0, 0
# mlm_loss, itm_loss, itc_loss = 0, 0, 0
if cfg.use_mlm:
# mlm_loss = outputs["mlm_loss"].mean()
mlm_loss = outputs["mlm_loss"]
mlm_mask = outputs["mlm_labels"] != -100
n_mlm_tokens = mlm_mask.sum().item()
task2loss["mlm"](mlm_loss.item())
if cfg.use_itm:
itm_loss = outputs["itm_loss"]
task2loss["itm"](itm_loss.item())
if cfg.use_itc:
itc_loss = outputs["itc_loss"]
task2loss["itc"](itc_loss.item())
if cfg.use_mpm:
mpm_loss = outputs["mpm_loss"]
task2loss["mpm"](mpm_loss.item())
loss = mlm_loss + itm_loss + itc_loss + mpm_loss
task2loss["loss"](loss.item())
if step % cfg.log_interval == 0:
# training mlm token acc
if n_mlm_tokens > 0:
n_mlm_corrects = (
outputs["mlm_scores"][mlm_mask].max(
dim=-1)[1] == outputs['mlm_labels'][mlm_mask]).sum().item()
else:
n_mlm_corrects = 0
# training itm acc
n_itm_ex = len(outputs["itm_labels"])
n_itm_corrects = (
outputs["itm_scores"].max(
dim=-1)[1] == outputs["itm_labels"]).sum().item()
# training mpm acc
mpm_labels = outputs["mpm_labels"]
if mpm_labels is not None:
n_mpm_ex = len(mpm_labels)
n_mpm_corrects = (
outputs["mpm_logits"].max(
dim=-1)[1] == outputs["mpm_labels"].max(dim=-1)[1]).sum().item()
mpm_acc = float(n_mpm_corrects / n_mpm_ex)
else:
mpm_acc = 0.
train_log.update({
'train/mlm_acc': float(n_mlm_corrects / n_mlm_tokens),
'train/itm_acc': float(n_itm_corrects / n_itm_ex),
'train/mpm_acc': mpm_acc
})
TB_LOGGER.log_scalar_dict(train_log)
delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
with amp.scale_loss(
loss, optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
zero_none_grad(model)
optimizer.synchronize()
# optimizer
if (step + 1) % cfg.gradient_accumulation_steps == 0:
global_step += 1
if (step + 1) % cfg.log_interval == 0:
TB_LOGGER.log_scalar_dict({l.name: l.val
for l in task2loss.values()
if l.val is not None})
n_epoch = int(1. * n_gpu * cfg.gradient_accumulation_steps *
global_step / train_loader.n_batches_in_epoch)
# learning rate scheduling for the whole model
lr_this_step = get_lr_sched(
global_step, cfg.decay, cfg.learning_rate,
cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,
decay_epochs=cfg.step_decay_epochs,
multi_step_epoch=n_epoch)
# Hardcoded param group length
# assert len(optimizer.param_groups) == 8
for pg_n, param_group in enumerate(
optimizer.param_groups):
param_group['lr'] = lr_this_step
if (step + 1) % cfg.log_interval == 0:
TB_LOGGER.add_scalar(
"train/lr", lr_this_step, global_step)
# update model params
if cfg.grad_norm != -1:
# import pdb; pdb.set_trace()
grad_norm = clip_grad_norm_(
amp.master_params(optimizer), cfg.grad_norm)
if (step + 1) % cfg.log_interval == 0:
TB_LOGGER.add_scalar("train/grad_norm", grad_norm, global_step)
TB_LOGGER.step()
# Check if there is None grad
none_grads = [
p[0] for p in model.named_parameters()
if p[1].requires_grad and p[1].grad is None]
assert len(none_grads) == 0, f"{none_grads}"
with optimizer.skip_synchronize():
optimizer.step()
optimizer.zero_grad()
restorer.step()
pbar.update(1)
# validate and checkpoint
if global_step % cfg.valid_steps == 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(model, val_loaders, cfg)
model_saver.save(step=global_step, model=model)
if global_step % save_steps == 0:
LOGGER.info(f'Step {global_step}: saving model checkpoints.')
model_saver.save(step=global_step, model=model)
if global_step >= cfg.num_train_steps:
break
if cfg.debug and global_step >= debug_step:
break
if global_step % cfg.valid_steps != 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(model, val_loaders, cfg)
model_saver.save(step=global_step, model=model)
if __name__ == '__main__':
# Initialize Horovod
hvd.init()
start_training()
================================================
FILE: src/tasks/run_video_qa.py
================================================
import math
import os
import random
import time
from collections import defaultdict
from os.path import join
import horovod.torch as hvd
import torch
from apex import amp
from easydict import EasyDict as edict
from src.configs.config import shared_configs
from src.datasets.data_utils import ImageNorm, mk_input_group
from src.datasets.dataloader import InfiniteIterator, PrefetchLoader
from src.datasets.dataset_video_qa import (AlproVideoQADataset,
VideoQACollator)
from src.modeling.alpro_models import AlproForSequenceClassification
from src.optimization.sched import get_lr_sched
from src.optimization.utils import setup_e2e_optimizer
from src.utils.basic_utils import (get_rounded_percentage, load_json,
load_jsonl, save_json)
from src.utils.distributed import all_gather_list
from src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer
from src.utils.load_save import (ModelSaver,
load_state_dict_with_pos_embed_resizing,
save_training_meta)
from src.utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from src.utils.misc import NoOp, set_random_seed, zero_none_grad
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import BertConfig, BertTokenizerFast
def mk_qa_dataloader(task_type, anno_path, lmdb_dir, cfg, tokenizer,
is_train=True, return_label=True):
"""
Returns:
list(dict), each dict is
msrvtt_qa: {
"answer": "couch",
"question": "what are three people sitting on?",
"video_id": "video6513",
"answer_type": "what"
}
"""
raw_datalist = load_jsonl(anno_path)
LOGGER.info(f"Loaded data size {len(raw_datalist)}")
if cfg.data_ratio != 1.0:
random.shuffle(raw_datalist)
raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
LOGGER.info(f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}")
datalist = []
qid = 0
for raw_d in raw_datalist:
d = dict(
question=raw_d["question"],
vid_id=raw_d["video_id"],
answer=raw_d["answer"], # int or str
question_id=qid # be careful, it is not unique across splits
)
qid += 1
d["answer_type"] = raw_d["answer_type"]
datalist.append(d)
LOGGER.info(f"datalist {len(datalist)}")
grouped = defaultdict(list) # examples grouped by image/video id
for d in datalist:
grouped[d["vid_id"]].append(d)
LOGGER.info(f"grouped {len(grouped)}")
# each group has a single video with multiple questions
group_datalist = mk_input_group(
grouped,
max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1, # force 1 in eval,
is_train=is_train
)
LOGGER.info(f"group_datalist {len(group_datalist)}")
ans2label = load_json(cfg.ans2label_path)
frm_sampling_strategy = cfg.frm_sampling_strategy
if not is_train:
# frm_sampling_strategy = "middle"
frm_sampling_strategy = "uniform"
if 'msvd' in cfg.task:
video_fmt = '.avi'
else:
video_fmt = '.mp4'
dataset = AlproVideoQADataset(
task_type=cfg.task,
datalist=group_datalist,
tokenizer=tokenizer,
img_lmdb_dir=lmdb_dir,
ans2label=ans2label,
max_img_size=cfg.crop_img_size,
max_txt_len=cfg.max_txt_len,
fps=cfg.fps,
num_frm=cfg.num_frm,
frm_sampling_strategy=frm_sampling_strategy,
ensemble_n_clips=cfg.train_n_clips if is_train else cfg.inference_n_clips,
return_label=return_label,
is_train=is_train,
img_db_type='rawvideo',
video_fmt=video_fmt
)
LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
f"each group {cfg.max_n_example_per_group if is_train else 1}")
if cfg.do_inference:
batch_size = cfg.inference_batch_size
else:
batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
sampler = DistributedSampler(
dataset, num_replicas=hvd.size(), rank=hvd.rank(),
shuffle=is_train)
vqa_collator = VideoQACollator(tokenizer=tokenizer,
max_length=cfg.max_txt_len,
task_type=cfg.task)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=False,
sampler=sampler,
num_workers=cfg.n_workers,
pin_memory=cfg.pin_mem,
collate_fn=vqa_collator.collate_batch)
return dataloader
def setup_dataloaders(cfg, tokenizer):
LOGGER.info("Init. train_loader and val_loader...")
train_loader = mk_qa_dataloader(
task_type=cfg.task,
anno_path=cfg.train_datasets[0].txt[cfg.task],
lmdb_dir=cfg.train_datasets[0].img,
cfg=cfg, tokenizer=tokenizer, is_train=True
)
val_loader = mk_qa_dataloader(
task_type=cfg.task,
anno_path=cfg.val_datasets[0].txt[cfg.task],
lmdb_dir=cfg.val_datasets[0].img,
cfg=cfg, tokenizer=tokenizer, is_train=False, return_label=False
)
img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
train_loader = PrefetchLoader(train_loader, img_norm)
val_loader = PrefetchLoader(val_loader, img_norm)
return train_loader, val_loader
def setup_model(cfg, device=None):
LOGGER.info("Setup model...")
# has to be a BertConfig instance
model_cfg = load_json(cfg.model_config)
model_cfg = BertConfig(**model_cfg)
# add downstream model config
add_attr_list = [
"num_labels", "classifier", "cls_hidden_scale",
"loss_type",
]
for k in add_attr_list:
setattr(model_cfg, k, cfg[k])
transformer_model_cls = AlproForSequenceClassification
# we separate the CNN and the transformer in order to use different optimizers for each
# the transformer still has a CNN layer inside, used to downsample the grid.
LOGGER.info("setup e2e model")
video_enc_cfg = load_json(cfg.visual_model_cfg)
video_enc_cfg['num_frm'] = cfg.num_frm
video_enc_cfg['img_size'] = cfg.crop_img_size
model = AlproForSequenceClassification(
model_cfg,
input_format=cfg.img_input_format,
video_enc_cfg=video_enc_cfg
)
if cfg.e2e_weights_path:
LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2
# NOTE strict is False if loaded from ALBEF ckpt
load_state_dict_with_pos_embed_resizing(model,
cfg.e2e_weights_path,
num_patches=num_patches,
num_frames=cfg.num_frm,
strict=False,
remove_text_encoder_prefix=True
)
# LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
# load_state_dict_with_mismatch(model, cfg.e2e_weights_path)
else:
LOGGER.info(f"Loading visual weights from {cfg.visual_weights_path}")
LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
model.load_separate_ckpt(
visual_weights_path=cfg.visual_weights_path,
bert_weights_path=cfg.bert_weights_path
)
# if cfg.freeze_cnn:
# model.freeze_cnn_backbone()
model.to(device)
LOGGER.info("Setup model done!")
return model
def forward_step(model, batch, cfg):
"""shared for training and validation"""
if cfg.task in ["action", "transition"]:
repeat_counts = [e * cfg.num_labels for e in batch["n_examples_list"]]
batch["n_examples_list"] = repeat_counts
outputs = model(batch) # dict
return outputs
@torch.no_grad()
def validate(model, val_loader, cfg, train_global_step, eval_score=True):
"""use eval_score=False when doing inference on test sets where answers are not available"""
model.eval()
loss = 0.
n_ex = 0
qa_results = []
st = time.time()
debug_step = 5
pbar = tqdm(total=len(val_loader))
for val_step, batch in enumerate(val_loader):
# forward pass
question_ids = batch["question_ids"]
bsz = len(question_ids)
# used to make visual feature copies
del batch["question_ids"]
# add visual part into the mini batch and perform inference
mini_batch = dict()
for k, v in batch.items():
if k != "visual_inputs":
mini_batch[k] = v
n_ex += len(question_ids)
# multi-frame test, scores across frames of the same video will be pooled together
pool_method = cfg.score_agg_func
# could be 1, where only a single clip is evaluated
num_clips = cfg.inference_n_clips
num_frm = cfg.num_frm
# (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
logits = []
losses = []
for clip_idx in range(num_clips):
# (B, num_frm, C, H, W)
mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
mini_batch["n_examples_list"] = batch["n_examples_list"]
outputs = forward_step(model, mini_batch, cfg)
logits.append(outputs["logits"].cpu())
_loss = outputs["loss"].sum().item() if isinstance(
outputs["loss"], torch.Tensor) else 0
losses.append(_loss)
loss += (sum(losses) / num_clips)
logits = torch.stack(logits) # (num_frm, B, 5)
if pool_method == "mean":
logits = logits.mean(0) # (B, 5)
elif pool_method == "max":
logits = logits.max(0)[0] # (B, 5)
elif pool_method == "lse":
logits = logits.permute(1, 0, 2).contiguous() # (B, num_frm, 5), pooling will be done in CE
logits = torch.logsumexp(logits, dim=1) # torch.exp alone might be too large and unstable
else:
raise ValueError(f"Invalid value for pool_method, "
f"got {pool_method}, expect one of [`mean`, `max`, `lse`]")
if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa", "msvd_qa"]:
# cross entropy
pred_labels = logits.max(dim=-1)[1].data.tolist()
else:
# mse
preds = (logits + 0.5).long().clamp(min=1, max=10)
pred_labels = preds.data.squeeze().tolist()
for qid, pred_label in zip(question_ids, pred_labels):
qa_results.append(dict(
question_id=qid,
answer=pred_label,
data=val_loader.dataset.qid2data[qid]
))
pbar.update(1)
if cfg.debug and val_step >= debug_step:
break
if cfg.debug:
LOGGER.info(qa_results[:10])
n_ex_per_rank = all_gather_list(n_ex)
loss = sum(all_gather_list(loss))
n_ex = sum(all_gather_list(n_ex))
# average loss for each example
val_log = {f'valid/loss': float(loss / n_ex)}
if eval_score:
LOGGER.info(f"QA Task [{cfg.task}], "
f"{len(qa_results)} qa_results,"
f"3 examples here: {qa_results[:3]}")
vqa_scores = val_loader.dataset.evaluate_qa(qa_results)
# print(f"{hvd.rank()}: {vqa_scores}")
# Gather scores
scores_per_rank = all_gather_list(vqa_scores)
gathered_scores = {}
if "ratios" in scores_per_rank[0]:
gathered_ratios = {
k: [0, 0] for k, _ in scores_per_rank[0]["ratios"].items()}
# Gather ratios
for rank_id in range(len(n_ex_per_rank)):
current_ratios = scores_per_rank[rank_id]["ratios"]
for k, v in current_ratios.items():
gathered_ratios[k][1] += v[1]
for k, v in gathered_ratios.items():
gathered_ratios[k][0] = get_rounded_percentage(
1. * v[1] / n_ex)
gathered_scores["ratios"] = gathered_ratios
# FIXME: Gather scores become complicated due to np.mean and dict format.
for scores_k, _ in vqa_scores.items():
if "ratio" in scores_k:
continue
gathered_v = 0
for rank_id, n in enumerate(n_ex_per_rank):
curr_acc, curr_n_ex = 0, 0
if "overall" in scores_k:
curr_acc = scores_per_rank[rank_id][scores_k] * n
else:
if "ratios" in scores_per_rank[0]:
curr_n_ex = scores_per_rank[
rank_id]["ratios"][
scores_k.replace("acc", "ratio")][1]
curr_acc = scores_per_rank[rank_id][
scores_k] * curr_n_ex
gathered_v += curr_acc
if "overall" in scores_k:
gathered_v = gathered_v * 1. / n_ex
else:
if "ratios" in scores_per_rank[0]:
_num = gathered_ratios[
scores_k.replace("acc", "ratio")][1]
gathered_v = gathered_v * 1. / _num if _num != 0 else 0
if cfg.task in ["action", "transition", "frameqa", "msrvtt_qa", "msvd_qa"]:
gathered_scores[scores_k] = get_rounded_percentage(
gathered_v)
else:
gathered_scores[scores_k] = round(gathered_v, 2)
for k, v in gathered_scores.items():
if "ratio" not in k:
val_log[f'valid/{k}'] = v
else:
LOGGER.info("eval_score = False, no scores are calculated.")
gathered_scores = 0
TB_LOGGER.log_scalar_dict(val_log)
LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
f"{gathered_scores}")
model.train()
return qa_results, gathered_scores
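# Illustrative sketch (editor addition, not part of the original file): the
# three score_agg_func options used in validate() above, applied to a toy
# stack of per-clip logits with hypothetical shapes (num_clips=4, B=2, 5 answers).
def _example_clip_pooling():
    logits = torch.randn(4, 2, 5)                     # (num_clips, B, num_labels)
    mean_pool = logits.mean(0)                        # (B, num_labels)
    max_pool = logits.max(0)[0]                       # (B, num_labels)
    lse_pool = torch.logsumexp(
        logits.permute(1, 0, 2).contiguous(), dim=1)  # (B, num_labels), numerically stable
    return mean_pool, max_pool, lse_pool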
def start_training(cfg):
set_random_seed(cfg.seed)
n_gpu = hvd.size()
cfg.n_gpu = n_gpu
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
if hvd.rank() != 0:
LOGGER.disabled = True
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), bool(cfg.fp16)))
model = setup_model(cfg, device=device)
model.train()
optimizer = setup_e2e_optimizer(model, cfg)
# Horovod: (optional) compression algorithm.
compression = hvd.Compression.none
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression)
# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
model, optimizer = amp.initialize(
model, optimizer, enabled=cfg.fp16, opt_level='O2',
keep_batchnorm_fp32=True)
# prepare data
tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
train_loader, val_loader = setup_dataloaders(cfg, tokenizer)
# compute the number of steps and update cfg
total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group
total_train_batch_size = int(
n_gpu * cfg.train_batch_size *
cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)
cfg.num_train_steps = int(math.ceil(
1. * cfg.num_train_epochs * total_n_examples / total_train_batch_size))
cfg.valid_steps = int(math.ceil(
1. * cfg.num_train_steps / cfg.num_valid /
cfg.min_valid_steps)) * cfg.min_valid_steps
actual_num_valid = int(math.floor(
1. * cfg.num_train_steps / cfg.valid_steps)) + 1
# restore
restorer = TrainingRestorer(cfg, model, optimizer)
global_step = restorer.global_step
TB_LOGGER.global_step = global_step
if hvd.rank() == 0:
LOGGER.info("Saving training meta...")
save_training_meta(cfg)
LOGGER.info("Saving training done...")
TB_LOGGER.create(join(cfg.output_dir, 'log'))
pbar = tqdm(total=cfg.num_train_steps)
model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
restorer = NoOp()
if global_step > 0:
pbar.update(global_step)
LOGGER.info(cfg)
LOGGER.info("Starting training...")
LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
LOGGER.info(f" Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
LOGGER.info(f" max_n_example_per_group = {cfg.max_n_example_per_group}")
LOGGER.info(f" Accumulate steps = {cfg.gradient_accumulation_steps}")
LOGGER.info(f" Total batch size = #GPUs * Single-GPU batch size * "
f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}")
LOGGER.info(f" Total #epochs = {cfg.num_train_epochs}")
LOGGER.info(f" Total #steps = {cfg.num_train_steps}")
LOGGER.info(f" Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times")
# quick hack for amp delay_unscale bug
with optimizer.skip_synchronize():
optimizer.zero_grad()
if global_step == 0:
optimizer.step()
debug_step = 3
running_loss = RunningMeter('train_loss')
for step, batch in enumerate(InfiniteIterator(train_loader)):
# forward pass
bsz = len(batch["question_ids"])
del batch["question_ids"]
mini_batch = dict()
for k, v in batch.items():
if k != "visual_inputs":
mini_batch[k] = v
pool_method = cfg.score_agg_func
# could be 1, where only a single clip is used
num_clips = cfg.train_n_clips
num_frm = cfg.num_frm
# (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
logits = []
for clip_idx in range(num_clips):
# (B, num_frm, C, H, W)
mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
mini_batch["n_examples_list"] = batch["n_examples_list"]
# outputs = forward_step(model, mini_batch, cfg)
outputs = forward_step(model, mini_batch, cfg)
logits.append(outputs)
# the losses are cross entropy and mse, no need to * num_labels
loss = outputs['loss']
loss = loss.mean()
running_loss(loss.item())
# backward pass
delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
with amp.scale_loss(
loss, optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
zero_none_grad(model)
optimizer.synchronize()
# optimizer
if (step + 1) % cfg.gradient_accumulation_steps == 0:
global_step += 1
# learning rate scheduling
n_epoch = int(1. * total_train_batch_size * global_step
/ total_n_examples)
# learning rate scheduling cnn
lr_this_step = get_lr_sched(
global_step, cfg.decay, cfg.learning_rate,
cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,
decay_epochs=cfg.step_decay_epochs,
multi_step_epoch=n_epoch)
# Hardcoded param group length
for pg_n, param_group in enumerate(
optimizer.param_groups):
param_group['lr'] = lr_this_step
if step % cfg.log_interval == 0:
TB_LOGGER.add_scalar(
"train/lr", lr_this_step, global_step)
TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)
# update model params
if cfg.grad_norm != -1:
grad_norm = clip_grad_norm_(
amp.master_params(optimizer),
cfg.grad_norm)
TB_LOGGER.add_scalar(
"train/grad_norm", grad_norm, global_step)
TB_LOGGER.step()
# Check if there is None grad
none_grads = [
p[0] for p in model.named_parameters()
if p[1].requires_grad and p[1].grad is None]
assert len(none_grads) == 0, f"{none_grads}"
with optimizer.skip_synchronize():
optimizer.step()
optimizer.zero_grad()
restorer.step()
pbar.update(1)
# checkpoint
if global_step % cfg.valid_steps == 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(
model, val_loader, cfg, global_step)
model_saver.save(step=global_step, model=model)
if global_step >= cfg.num_train_steps:
break
if cfg.debug and global_step >= debug_step:
break
if global_step % cfg.valid_steps != 0:
LOGGER.info(f'Step {global_step}: start validation')
qa_results, qa_scores = validate(
model, val_loader, cfg, global_step)
model_saver.save(step=global_step, model=model)
def start_inference(cfg):
set_random_seed(cfg.seed)
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
if hvd.rank() != 0:
LOGGER.disabled = True
inference_res_dir = join(
cfg.output_dir,
f"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}"
)
if hvd.rank() == 0:
os.makedirs(inference_res_dir, exist_ok=True)
save_json(cfg, join(inference_res_dir, "raw_args.json"),
save_pretty=True)
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), bool(cfg.fp16)))
# overwrite cfg with stored_cfg,
# but skip keys containing the keyword 'inference'
stored_cfg_path = join(cfg.output_dir, "log/args.json")
stored_cfg = edict(load_json(stored_cfg_path))
for k, v in cfg.items():
if k in stored_cfg and "inference" not in k:
setattr(cfg, k, stored_cfg[k])
# setup models
cfg.model_config = join(cfg.output_dir, "log/model_config.json")
e2e_weights_path = join(
cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
cfg.e2e_weights_path = e2e_weights_path
model = setup_model(cfg, device=device)
model.eval()
# FIXME separate scaling for each loss
model = amp.initialize(
model, enabled=cfg.fp16, opt_level='O2')
global_step = 0
# prepare data
tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
cfg.data_ratio = 1.
val_loader = mk_qa_dataloader(
task_type=cfg.task,
anno_path=cfg.inference_txt_db,
lmdb_dir=cfg.inference_img_db,
cfg=cfg, tokenizer=tokenizer,
is_train=False,
return_label=False
)
img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
val_loader = PrefetchLoader(val_loader, img_norm)
LOGGER.info(cfg)
LOGGER.info("Starting inference...")
LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
LOGGER.info(f" Batch size = {cfg.inference_batch_size}")
LOGGER.info(f'Step {global_step}: start validation')
qa_results, qa_scores = validate(
model, val_loader, cfg, global_step,
eval_score=True) # cfg.inference_split == "val"
if hvd.rank() == 0:
save_json(cfg, join(inference_res_dir, "merged_args.json"),
save_pretty=True)
save_json(qa_scores, join(inference_res_dir, "scores.json"),
save_pretty=True)
# ###### Saving with Horovod ####################
# dummy sync
_ = None
all_gather_list(_)
if n_gpu > 1:
# with retrial, as azure blob fails occasionally.
max_save_load_trial = 10
save_trial = 0
while save_trial < max_save_load_trial:
try:
LOGGER.info(f"Save results trial NO. {save_trial}")
save_json(
qa_results,
join(inference_res_dir, f"results_rank{hvd.rank()}.json"))
break
except Exception as e:
save_trial += 1
# dummy sync
_ = None
all_gather_list(_)
# join results
if n_gpu > 1 and hvd.rank() == 0:
qa_results = []
for rk in range(n_gpu):
qa_results.extend(load_json(
join(inference_res_dir, f"results_rank{rk}.json")))
LOGGER.info(f'results joined')
if hvd.rank() == 0:
save_json(
qa_results,
join(inference_res_dir, f"results_all.json"))
LOGGER.info(f'all results written')
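# Illustrative sketch (editor addition, not part of the original file): the
# rank-0 merge performed at the end of start_inference(), rewritten as a
# standalone helper with a hypothetical results directory and world size.
def _example_merge_rank_results(inference_res_dir, world_size):
    merged = []
    for rk in range(world_size):
        merged.extend(load_json(
            join(inference_res_dir, f"results_rank{rk}.json")))
    return merged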
if __name__ == '__main__':
# Initialize Horovod
hvd.init()
input_cfg = shared_configs.get_video_qa_args()
if input_cfg.do_inference:
# assert hvd.size() == 1, \
# "Please use single GPU for evaluation! " \
# "Multi-GPU might miss some examples."
start_inference(input_cfg)
else:
start_training(input_cfg)
================================================
FILE: src/tasks/run_video_retrieval.py
================================================
import json
import math
import os
import random
import time
from collections import defaultdict
from os.path import exists, join
import horovod.torch as hvd
import numpy as np
import torch
import torch.nn.functional as F
from apex import amp
from easydict import EasyDict as edict
from src.configs.config import shared_configs
from src.datasets.data_utils import ImageNorm, mk_input_group
from src.datasets.dataloader import InfiniteIterator, PrefetchLoader
from src.datasets.dataset_video_retrieval import (
AlproVideoRetrievalDataset, AlproVideoRetrievalEvalDataset,
VideoRetrievalCollator)
from src.modeling.alpro_models import AlproForVideoTextRetrieval
from src.optimization.sched import get_lr_sched
from src.optimization.utils import setup_e2e_optimizer
from src.utils.basic_utils import (get_rounded_percentage, load_json,
load_jsonl, save_json)
from src.utils.distributed import all_gather_list
from src.utils.load_save import E2E_TrainingRestorer as TrainingRestorer
from src.utils.load_save import (ModelSaver,
load_state_dict_with_pos_embed_resizing,
save_training_meta)
from src.utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from src.utils.misc import NoOp, set_random_seed, zero_none_grad
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import BertConfig, BertTokenizerFast
def mk_video_ret_datalist(raw_datalist, cfg):
"""
Args:
raw_datalist: list(dict)
Each data point is {id: int, txt: str, vid_id: str}
Returns:
"""
LOGGER.info(f"Loaded data size {len(raw_datalist)}")
if cfg.data_ratio != 1.0:
random.shuffle(raw_datalist)
raw_datalist = raw_datalist[:int(len(raw_datalist) * cfg.data_ratio)]
LOGGER.info(f"Use {100 * cfg.data_ratio}% of the loaded data: {len(raw_datalist)}")
datalist = []
qid = 0
for raw_d in raw_datalist:
d = dict(
id=qid,
txt=raw_d["caption"],
vid_id=raw_d["clip_name"]
)
qid += 1
datalist.append(d)
LOGGER.info(f"datalist {len(datalist)}")
return datalist
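# Worked example (editor addition, not part of the original file): the raw-to-
# datalist mapping above, applied to one hypothetical MSRVTT-style entry:
#   raw:  {"caption": "a man plays guitar", "clip_name": "video7010"}
#   -->   {"id": 0, "txt": "a man plays guitar", "vid_id": "video7010"}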
def mk_video_ret_dataloader(anno_path, lmdb_dir, cfg, tokenizer, is_train=True):
""""""
raw_datalist = load_jsonl(anno_path)
datalist = mk_video_ret_datalist(raw_datalist, cfg)
grouped = defaultdict(list) # examples grouped by image/video id
for d in datalist:
grouped[d["vid_id"]].append(d)
LOGGER.info(f"grouped {len(grouped)}")
# each group has a single video with multiple captions
group_datalist = mk_input_group(
grouped,
max_n_example_per_group=cfg.max_n_example_per_group if is_train else 1, # force 1 in eval,
is_train=is_train
)
LOGGER.info(f"group_datalist {len(group_datalist)}")
frm_sampling_strategy = cfg.frm_sampling_strategy
if not is_train and frm_sampling_strategy == "rand":
frm_sampling_strategy = "uniform"
if 'msvd' in cfg.train_datasets[0]['name']:
video_fmt = '.avi'
else:
video_fmt = '.mp4'
dataset = AlproVideoRetrievalDataset(
datalist=group_datalist,
tokenizer=tokenizer,
img_lmdb_dir=lmdb_dir,
max_img_size=cfg.crop_img_size,
max_txt_len=cfg.max_txt_len,
fps=cfg.fps,
num_frm=cfg.num_frm,
frm_sampling_strategy=frm_sampling_strategy,
itm_neg_size=0,
is_train=is_train,
img_db_type='rawvideo',
video_fmt=video_fmt
)
LOGGER.info(f"is_train {is_train}, dataset size {len(dataset)} groups, "
f"each group {cfg.max_n_example_per_group if is_train else 1}")
if cfg.do_inference:
batch_size = cfg.inference_batch_size
else:
batch_size = cfg.train_batch_size if is_train else cfg.val_batch_size
sampler = DistributedSampler(
dataset, num_replicas=hvd.size(), rank=hvd.rank(),
shuffle=is_train)
vqa_collator = VideoRetrievalCollator(
tokenizer=tokenizer, max_length=cfg.max_txt_len)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=False,
sampler=sampler,
num_workers=cfg.n_workers,
pin_memory=cfg.pin_mem,
collate_fn=retrieval_collator.collate_batch)
return dataloader
def mk_video_ret_eval_dataloader(anno_path, lmdb_dir, cfg, tokenizer):
"""
eval_retrieval: bool, will sample one video per batch paired with multiple text.
Returns:
"""
raw_datalist = load_jsonl(anno_path)
datalist = mk_video_ret_datalist(raw_datalist, cfg)
frm_sampling_strategy = cfg.frm_sampling_strategy
if frm_sampling_strategy == "rand":
frm_sampling_strategy = "uniform"
if 'msvd' in cfg.train_datasets[0]['name']:
video_fmt = '.avi'
else:
video_fmt = '.mp4'
dataset = AlproVideoRetrievalEvalDataset(
datalist=datalist,
tokenizer=tokenizer,
img_lmdb_dir=lmdb_dir,
max_img_size=cfg.crop_img_size,
max_txt_len=cfg.max_txt_len,
fps=cfg.fps,
num_frm=cfg.num_frm,
frm_sampling_strategy=frm_sampling_strategy,
video_fmt=video_fmt,
img_db_type='rawvideo'
)
sampler = DistributedSampler(
dataset, num_replicas=hvd.size(), rank=hvd.rank(),
shuffle=False)
retrieval_collator = VideoRetrievalCollator(
tokenizer=tokenizer, max_length=cfg.max_txt_len)
dataloader = DataLoader(dataset,
batch_size=1, # already batched in dataset
shuffle=False,
sampler=sampler,
num_workers=cfg.n_workers,
pin_memory=cfg.pin_mem,
collate_fn=retrieval_collator.collate_batch)
img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
dataloader = PrefetchLoader(dataloader, img_norm)
return dataloader
def setup_dataloaders(cfg, tokenizer):
LOGGER.info("Init. train_loader and val_loader...")
train_loader = mk_video_ret_dataloader(
anno_path=cfg.train_datasets[0].txt,
lmdb_dir=cfg.train_datasets[0].img,
cfg=cfg, tokenizer=tokenizer, is_train=True
)
val_loader = mk_video_ret_dataloader(
anno_path=cfg.val_datasets[0].txt,
lmdb_dir=cfg.val_datasets[0].img,
cfg=cfg, tokenizer=tokenizer, is_train=False
)
img_norm = ImageNorm(mean=cfg.img_pixel_mean, std=cfg.img_pixel_std)
train_loader = PrefetchLoader(train_loader, img_norm)
val_loader = PrefetchLoader(val_loader, img_norm)
return train_loader, val_loader
def setup_model(cfg, device=None):
LOGGER.info("Setup model...")
# has to be a BertConfig instance
model_cfg = load_json(cfg.model_config)
model_cfg = BertConfig(**model_cfg)
# add downstream model config
add_attr_list = []
for k in add_attr_list:
setattr(model_cfg, k, cfg[k])
# the visual encoder (TimeSformer) and the text transformer are configured
# separately so their weights can be loaded and optimized independently.
LOGGER.info("setup e2e model")
video_enc_cfg = load_json(cfg.visual_model_cfg)
video_enc_cfg['num_frm'] = cfg.num_frm
video_enc_cfg['img_size'] = cfg.crop_img_size
model = AlproForVideoTextRetrieval(
model_cfg,
input_format=cfg.img_input_format,
video_enc_cfg=video_enc_cfg
)
if cfg.e2e_weights_path:
LOGGER.info(f"Loading e2e weights from {cfg.e2e_weights_path}")
num_patches = (cfg.crop_img_size // video_enc_cfg['patch_size']) ** 2
# NOTE strict is False when loading from an ALBEF checkpoint
load_state_dict_with_pos_embed_resizing(model,
cfg.e2e_weights_path,
num_patches=num_patches,
num_frames=cfg.num_frm,
strict=False,
)
else:
LOGGER.info(f"Loading visual weights from {cfg.visual_weights_path}")
LOGGER.info(f"Loading bert weights from {cfg.bert_weights_path}")
model.load_separate_ckpt(
visual_weights_path=cfg.visual_weights_path,
bert_weights_path=cfg.bert_weights_path
)
# if cfg.freeze_cnn:
# model.freeze_cnn_backbone()
model.to(device)
LOGGER.info("Setup model done!")
return model
def forward_step(model, batch):
"""shared for training and validation"""
outputs = model(batch) # dict
return outputs
def forward_inference_step(model, batch):
outputs = model.forward_inference(batch)
return outputs
@torch.no_grad()
def validate(model, val_loader, eval_loader, cfg, train_global_step, eval_filepath):
"""use eval_score=False when doing inference on test sets where answers are not available"""
model.eval()
loss = 0.
n_ex = 0
n_corrects = 0
st = time.time()
debug_step = 10
for val_step, batch in enumerate(val_loader):
# forward pass
del batch["caption_ids"]
outputs = forward_step(model, batch)
targets = batch['labels']
batch_loss = outputs['itm_loss'] + outputs['itc_loss']
if isinstance(batch_loss, torch.Tensor):
loss += batch_loss.sum().item()
else:
raise NotImplementedError('Expecting loss as Tensor, found: {}'.format(type(batch_loss)))
n_ex += len(targets)
if cfg.debug and val_step >= debug_step:
break
loss = sum(all_gather_list(loss))
n_ex = sum(all_gather_list(n_ex))
n_corrects = sum(all_gather_list(n_corrects))
_, retrieval_metrics = inference_retrieval(model, eval_loader, eval_filepath, cfg)
model.train()
if hvd.rank() == 0:
# average loss for each example
acc = float(n_corrects / n_ex)
val_log = {'valid/loss': float(loss / n_ex), 'valid/acc': acc}
for ret_type, ret_m in retrieval_metrics.items():
val_log.update({f"valid/{ret_type}_{k}": round(v, 4) for k, v in ret_m.items()})
TB_LOGGER.log_scalar_dict(val_log)
LOGGER.info(f"validation finished in {int(time.time() - st)} seconds."
f"itm_acc: {acc}. Retrieval res {retrieval_metrics}")
def start_training(cfg):
set_random_seed(cfg.seed)
n_gpu = hvd.size()
cfg.n_gpu = n_gpu
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
if hvd.rank() != 0:
LOGGER.disabled = True
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), bool(cfg.fp16)))
model = setup_model(cfg, device=device)
model.train()
optimizer = setup_e2e_optimizer(model, cfg)
# Horovod: (optional) gradient compression algorithm.
compression = hvd.Compression.none
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression)
# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
model, optimizer = amp.initialize(
model, optimizer, enabled=cfg.fp16, opt_level='O2',
keep_batchnorm_fp32=True)
# prepare data
tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
train_loader, val_loader = setup_dataloaders(cfg, tokenizer)
eval_loader = mk_video_ret_eval_dataloader(
anno_path=cfg.val_datasets[0].txt,
lmdb_dir=cfg.val_datasets[0].img,
cfg=cfg, tokenizer=tokenizer,
)
# compute the number of steps and update cfg
total_n_examples = len(train_loader.dataset) * cfg.max_n_example_per_group
total_train_batch_size = int(
n_gpu * cfg.train_batch_size *
cfg.gradient_accumulation_steps * cfg.max_n_example_per_group)
cfg.num_train_steps = int(math.ceil(
1. * cfg.num_train_epochs * total_n_examples / total_train_batch_size))
cfg.valid_steps = int(math.ceil(
1. * cfg.num_train_steps / cfg.num_valid /
cfg.min_valid_steps)) * cfg.min_valid_steps
actual_num_valid = int(math.floor(
1. * cfg.num_train_steps / cfg.valid_steps)) + 1
# restore
restorer = TrainingRestorer(cfg, model, optimizer)
global_step = restorer.global_step
TB_LOGGER.global_step = global_step
if hvd.rank() == 0:
LOGGER.info("Saving training meta...")
save_training_meta(cfg)
LOGGER.info("Saving training done...")
TB_LOGGER.create(join(cfg.output_dir, 'log'))
pbar = tqdm(total=cfg.num_train_steps)
model_saver = ModelSaver(join(cfg.output_dir, "ckpt"))
add_log_to_file(join(cfg.output_dir, "log", "log.txt"))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
restorer = NoOp()
if global_step > 0:
pbar.update(global_step)
LOGGER.info(cfg)
LOGGER.info("Starting training...")
LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
LOGGER.info(f" Single-GPU Non-Accumulated batch size = {cfg.train_batch_size}")
LOGGER.info(f" max_n_example_per_group = {cfg.max_n_example_per_group}")
LOGGER.info(f" Accumulate steps = {cfg.gradient_accumulation_steps}")
LOGGER.info(f" Total batch size = #GPUs * Single-GPU batch size * "
f"max_n_example_per_group * Accumulate steps [Image] = {total_train_batch_size}")
LOGGER.info(f" Total #epochs = {cfg.num_train_epochs}")
LOGGER.info(f" Total #steps = {cfg.num_train_steps}")
LOGGER.info(f" Validate every {cfg.valid_steps} steps, in total {actual_num_valid} times")
LOGGER.info(f'Step {global_step}: start validation')
validate(
model, val_loader, eval_loader, cfg, global_step,
eval_filepath=cfg.val_datasets[0].txt)
# quick hack for amp delay_unscale bug
with optimizer.skip_synchronize():
optimizer.zero_grad()
if global_step == 0:
optimizer.step()
debug_step = 3
running_loss = RunningMeter('train_loss')
for step, batch in enumerate(InfiniteIterator(train_loader)):
# forward pass
del batch["caption_ids"]
mini_batch = dict()
for k, v in batch.items():
if k != "visual_inputs":
mini_batch[k] = v
pool_method = cfg.score_agg_func
# could be 1, where only a single clip is used
num_clips = cfg.train_n_clips
assert num_clips == 1, "Support only single clip for now."
num_frm = cfg.num_frm
# (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
bsz = batch["visual_inputs"].shape[0]
new_visual_shape = (bsz, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
model_out = []
for clip_idx in range(num_clips):
# (B, num_frm, C, H, W)
mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
mini_batch["n_examples_list"] = batch["n_examples_list"]
# outputs = forward_step(model, mini_batch, cfg)
outputs = forward_step(model, mini_batch)
model_out.append(outputs)
# total loss is the sum of the video-text matching (ITM) and contrastive (ITC) losses
loss_itm = outputs['itm_loss']
loss_itc = outputs['itc_loss']
loss = loss_itm + loss_itc
running_loss(loss.item())
# backward pass
delay_unscale = (step + 1) % cfg.gradient_accumulation_steps != 0
with amp.scale_loss(
loss, optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
zero_none_grad(model)
optimizer.synchronize()
# optimizer
if (step + 1) % cfg.gradient_accumulation_steps == 0:
global_step += 1
# learning rate scheduling
n_epoch = int(1. * total_train_batch_size * global_step
/ total_n_examples)
# learning rate scheduling cnn
lr_this_step = get_lr_sched(
global_step, cfg.decay, cfg.learning_rate,
cfg.num_train_steps, warmup_ratio=cfg.warmup_ratio,
decay_epochs=cfg.step_decay_epochs,
multi_step_epoch=n_epoch)
# apply the same learning rate to every param group
for pg_n, param_group in enumerate(
optimizer.param_groups):
param_group['lr'] = lr_this_step
if step % cfg.log_interval == 0:
TB_LOGGER.add_scalar(
"train/lr", lr_this_step, global_step)
TB_LOGGER.add_scalar('train/loss', running_loss.val, global_step)
# update model params
if cfg.grad_norm != -1:
grad_norm = clip_grad_norm_(
amp.master_params(optimizer),
cfg.grad_norm)
TB_LOGGER.add_scalar(
"train/grad_norm", grad_norm, global_step)
TB_LOGGER.step()
# Check if there is None grad
none_grads = [
p[0] for p in model.named_parameters()
if p[1].requires_grad and p[1].grad is None]
assert len(none_grads) == 0, f"{none_grads}"
with optimizer.skip_synchronize():
optimizer.step()
optimizer.zero_grad()
restorer.step()
pbar.update(1)
# checkpoint
if global_step % cfg.valid_steps == 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(
model, val_loader, eval_loader, cfg, global_step,
eval_filepath=cfg.val_datasets[0].txt)
model_saver.save(step=global_step, model=model)
if global_step >= cfg.num_train_steps:
break
if cfg.debug and global_step >= debug_step:
break
if global_step % cfg.valid_steps != 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(
model, val_loader, eval_loader, cfg, global_step,
eval_filepath=cfg.val_datasets[0].txt)
model_saver.save(step=global_step, model=model)
def get_retrieval_metric_from_bool_matrix(bool_matrix):
""" Calc Recall@K, median rank and mean rank.
Args:
bool_matrix: np array of shape (#txt, #vid), dtype bool,
sorted row-wise from most similar to least similar.
The GT position is marked as 1 and all the others are 0;
each row contains exactly one 1.
Returns:
retrieval_metrics: dict(
R1=.., R5=..., R10=..., MedR=..., MeanR=...
)
"""
num_row = bool_matrix.shape[0] # #rows
row_range, gt_ranks = np.where(bool_matrix == 1)
assert np.allclose(row_range, np.arange(len(row_range))), \
"each row should have only a single GT"
retrieval_metrics = dict(
r1=100 * bool_matrix[:, 0].sum() / num_row,
r5=100 * bool_matrix[:, :5].sum() / num_row,
r10=100 * bool_matrix[:, :10].sum() / num_row,
medianR=np.median(gt_ranks+1), # convert to 1-indexed system instead of 0-indexed.
meanR=np.mean(gt_ranks+1)
)
return retrieval_metrics
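# Editor's note: illustrative sketch (not part of the original code) of the
# metric computation on a tiny hand-made bool matrix; 3 captions, 4 videos,
# with GT ranks of 1, 3 and 2 for the three rows respectively:
#   bm = np.array([[1, 0, 0, 0],
#                  [0, 0, 1, 0],
#                  [0, 1, 0, 0]], dtype=bool)
#   get_retrieval_metric_from_bool_matrix(bm)
#   # -> r1 ~= 33.3, r5 = 100.0, r10 = 100.0, medianR = 2.0, meanR = 2.0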
def get_retrieval_scores(score_matrix, gt_row2col_id_mapping, row_idx2id, col_id2idx):
# rank scores
score_matrix_sorted, indices_sorted = \
torch.sort(score_matrix, dim=1, descending=True) # (#txt, #vid)
# build bool matrix, where the GT position is marked as 1, all the others are 0,
num_row = len(score_matrix)
gt_col_indices = torch.zeros(num_row, 1)
for idx in range(num_row):
gt_col_id = gt_row2col_id_mapping[row_idx2id[idx]]
gt_col_indices[idx, 0] = col_id2idx[gt_col_id]
bool_matrix = indices_sorted == gt_col_indices # (#txt, #vid)
retrieval_metrics = get_retrieval_metric_from_bool_matrix(bool_matrix.numpy())
return retrieval_metrics
def eval_retrieval(vid_txt_score_dicts, gt_txt_id2vid_id, id2data):
"""
Args:
vid_txt_score_dicts: list(dict), each dict is dict(vid_id=..., txt_id=..., score=...)
gt_txt_id2vid_id: dict, ground-truth {txt_id: vid_id}
id2data: dict, {txt_id: single_example_data}
Returns:
"""
# group prediction by txt_id
scores_group_by_txt_ids = defaultdict(list)
for d in vid_txt_score_dicts:
scores_group_by_txt_ids[d["txt_id"]].append(d)
# clean duplicated videos
_scores_group_by_txt_ids = defaultdict(list)
for txt_id, txt_vid_pairs in scores_group_by_txt_ids.items():
added_vid_ids = []
for d in txt_vid_pairs:
if d["vid_id"] not in added_vid_ids:
_scores_group_by_txt_ids[txt_id].append(d)
added_vid_ids.append(d["vid_id"])
scores_group_by_txt_ids = _scores_group_by_txt_ids
num_txt = len(scores_group_by_txt_ids)
any_key = list(scores_group_by_txt_ids.keys())[0]
vid_ids = [d["vid_id"] for d in scores_group_by_txt_ids[any_key]]
num_vid = len(vid_ids)
assert len(set(vid_ids)) == num_vid, "Each caption should be compared to each video only once."
for k, v in scores_group_by_txt_ids.items():
assert num_vid == len(v), "each caption should be compared with the same #videos."
# row/col indices in the score matrix
# *_id are the original ids, *_idx are the matrix indices
txt_id2idx = {txt_id: idx for idx, txt_id in enumerate(scores_group_by_txt_ids)}
vid_id2idx = {vid_id: idx for idx, vid_id in enumerate(vid_ids)}
txt_idx2id = {v: k for k, v in txt_id2idx.items()}
vid_idx2id = {v: k for k, v in vid_id2idx.items()}
# build score (float32) and vid_id (str) matrix
score_matrix = torch.zeros(num_txt, num_vid)
sim_matrix = torch.zeros(num_txt, num_vid)
for txt_id, preds in scores_group_by_txt_ids.items():
txt_idx = txt_id2idx[txt_id]
for p in preds:
vid_idx = vid_id2idx[p["vid_id"]]
score_matrix[txt_idx, vid_idx] = p["score"]
sim_matrix[txt_idx, vid_idx] = p['sim']
# [dxli] discard pairs with low ITC similarity scores
# top_k, indices = torch.topk(sim_matrix, dim=1, k=100)
# new_sim_matrix = torch.zeros_like(sim_matrix)
# new_sim_matrix = new_sim_matrix.scatter(1, indices, top_k)
# score_matrix[new_sim_matrix == 0] = 0
# text to video retrieval, score_matrix--> (#txt, #vid)
# given a text, retrieve most relevant videos
t2v_retrieval_metrics = get_retrieval_scores(
score_matrix, gt_txt_id2vid_id, txt_idx2id, vid_id2idx)
# video to text retrieval, score_matrix--> (#vid, #txt)
# given a video, retrieve most relevant texts
score_matrix = score_matrix.transpose(0, 1)
gt_vid_id2txt_id = {v: k for k, v in gt_txt_id2vid_id.items()}
v2t_retrieval_metrics = get_retrieval_scores(
score_matrix, gt_vid_id2txt_id, vid_idx2id, txt_id2idx)
retrieval_metrics = dict(
text2video=t2v_retrieval_metrics,
video2text=v2t_retrieval_metrics
)
return retrieval_metrics
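# Editor's note: a small hypothetical example (not from the repo) of the input
# format expected by eval_retrieval; every (txt_id, vid_id) pair carries an ITM
# probability `score` and an ITC similarity `sim`:
#   preds = [
#       dict(txt_id=0, vid_id="vidA", score=0.9, sim=0.8),
#       dict(txt_id=0, vid_id="vidB", score=0.1, sim=0.2),
#       dict(txt_id=1, vid_id="vidA", score=0.3, sim=0.4),
#       dict(txt_id=1, vid_id="vidB", score=0.7, sim=0.6),
#   ]
#   eval_retrieval(preds, gt_txt_id2vid_id={0: "vidA", 1: "vidB"}, id2data={})
#   # -> {"text2video": {"r1": 100.0, ...}, "video2text": {"r1": 100.0, ...}}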
@torch.no_grad()
def inference_retrieval(model, val_loader, eval_file_path, cfg):
model.eval()
retrieval_res = [] # list(dict): dict(vid_id=..., txt_id=..., score=...)
st = time.time()
eval_bsz = cfg.inference_batch_size if cfg.do_inference else cfg.eval_retrieval_batch_size
LOGGER.info(f"Evaluate retrieval #video per GPU: {len(val_loader)}")
if hvd.rank() == 0:
pbar = tqdm(total=len(val_loader), desc="eval")
for batch in val_loader:
# each batch contains 1 video and N (=1000) captions
n_mini_batches = math.ceil(len(batch["caption_ids"]) / eval_bsz)
vid_id = batch["vid_id"]
for idx in range(n_mini_batches):
# compile shared text part
mini_batch = dict()
for k in ["text_input_ids", "text_input_mask", "labels"]:
if batch[k] is not None:
mini_batch[k] = batch[k][idx * eval_bsz:(idx + 1) * eval_bsz]
else:
mini_batch[k] = None
caption_ids = batch["caption_ids"][idx * eval_bsz:(idx + 1) * eval_bsz]
# bsz = len(caption_ids)
mini_batch["n_examples_list"] = [len(caption_ids)]
num_clips = cfg.inference_n_clips
num_frm = cfg.num_frm
# (B, T=num_clips*num_frm, C, H, W) --> (B, num_clips, num_frm, C, H, W)
new_visual_shape = (1, num_clips, num_frm) + batch["visual_inputs"].shape[2:]
visual_inputs = batch["visual_inputs"].view(*new_visual_shape)
logits = []
sim_scores = []
for clip_idx in range(num_clips):
mini_batch["visual_inputs"] = visual_inputs[:, clip_idx]
if cfg.fp16:
# FIXME not sure why we need to do this explicitly?
mini_batch["visual_inputs"] = mini_batch["visual_inputs"].half()
outputs = forward_inference_step(model, mini_batch)
logits.append(outputs["logits"].cpu())
sim_scores.append(outputs["itc_scores"].cpu())
logits = torch.stack(logits) # (num_clips, B, 2)
sim_scores = torch.stack(sim_scores)
# FIXME not sure why need to convert dtype explicitly
logits = logits.squeeze().float()
sim_scores = sim_scores.squeeze().float().tolist()
if logits.shape[1] == 2:
# [dxli] label 1 is the positive (matched) class, 0 the negative,
# so take the softmax probability at index 1
probs = F.softmax(logits, dim=1)[:, 1].tolist()
else:
raise NotImplementedError('Not supported (unclear purposes)!')
for cap_id, score, sim in zip(caption_ids, probs, sim_scores):
retrieval_res.append(dict(
vid_id=vid_id,
txt_id=cap_id,
score=round(score, 4),
sim=round(sim, 4)
))
if hvd.rank() == 0:
pbar.update(1)
# ###### Saving with Horovod ####################
# dummy sync
_ = None
all_gather_list(_)
n_gpu = hvd.size()
eval_dir = join(cfg.output_dir, f"results_{os.path.splitext(os.path.basename(eval_file_path))[0]}")
os.makedirs(eval_dir, exist_ok=True)
if n_gpu > 1:
# with retrial, as azure blob fails occasionally.
max_save_load_trial = 10
save_trial = 0
while save_trial < max_save_load_trial:
try:
LOGGER.info(f"Save results trial NO. {save_trial}")
save_json(
retrieval_res,
join(eval_dir, f"tmp_results_rank{hvd.rank()}.json"))
break
except Exception as e:
print(f"Saving exception: {e}")
save_trial += 1
# dummy sync
_ = None
all_gather_list(_)
# join results
if n_gpu > 1 and hvd.rank() == 0:
retrieval_res = []
for rk in range(n_gpu):
retrieval_res.extend(load_json(
join(eval_dir, f"tmp_results_rank{rk}.json")))
LOGGER.info('results joined')
if hvd.rank() == 0:
retrieval_metrics = eval_retrieval(
retrieval_res, val_loader.dataset.gt_cap_id2vid_id, val_loader.dataset.id2data)
LOGGER.info(f"validation finished in {int(time.time() - st)} seconds. scores: {retrieval_metrics}")
else:
retrieval_metrics = None
model.train()
return retrieval_res, retrieval_metrics
def start_inference(cfg):
set_random_seed(cfg.seed)
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
if hvd.rank() != 0:
LOGGER.disabled = True
inference_res_dir = join(
cfg.output_dir,
f"results_{os.path.splitext(os.path.basename(cfg.inference_txt_db))[0]}/"
f"step_{cfg.inference_model_step}_{cfg.inference_n_clips}_{cfg.score_agg_func}"
)
if hvd.rank() == 0:
os.makedirs(inference_res_dir, exist_ok=True)
save_json(cfg, join(inference_res_dir, "raw_args.json"),
save_pretty=True)
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), bool(cfg.fp16)))
# overwrite cfg with stored_cfg,
# but skip keys containing the keyword 'inference'
stored_cfg_path = join(cfg.output_dir, "log/args.json")
stored_cfg = edict(load_json(stored_cfg_path))
for k, v in cfg.items():
if k in stored_cfg and "inference" not in k and "output_dir" not in k:
setattr(cfg, k, stored_cfg[k])
# setup models
cfg.model_config = join(cfg.output_dir, "log/model_config.json")
e2e_weights_path = join(
cfg.output_dir, f"ckpt/model_step_{cfg.inference_model_step}.pt")
if exists(e2e_weights_path):
cfg.e2e_weights_path = e2e_weights_path
else:
raise NotImplementedError("Not supporting loading separate weights for inference.")
model = setup_model(cfg, device=device)
model.eval()
# FIXME separate scaling for each loss
model = amp.initialize(
model, enabled=cfg.fp16, opt_level='O2')
global_step = 0
# prepare data
tokenizer = BertTokenizerFast.from_pretrained(cfg.tokenizer_dir)
cfg.data_ratio = 1.
val_loader = mk_video_ret_eval_dataloader(
anno_path=cfg.inference_txt_db,
lmdb_dir=cfg.inference_img_db,
cfg=cfg, tokenizer=tokenizer,
)
LOGGER.info(cfg)
LOGGER.info("Starting inference...")
LOGGER.info(f"***** Running inference with {n_gpu} GPUs *****")
LOGGER.info(f" Batch size = {cfg.inference_batch_size}")
LOGGER.info(f'Step {global_step}: start validation')
ret_results, ret_scores = inference_retrieval(
model, val_loader, cfg.inference_txt_db, cfg)
if hvd.rank() == 0:
save_json(cfg, join(inference_res_dir, "merged_args.json"),
save_pretty=True)
save_json(ret_results, join(inference_res_dir, "results.json"),
save_pretty=True)
save_json(ret_scores, join(inference_res_dir, "scores.json"),
save_pretty=True)
if __name__ == '__main__':
# Initialize Horovod
hvd.init()
input_cfg = shared_configs.get_video_retrieval_args()
if input_cfg.do_inference:
start_inference(input_cfg)
else:
start_training(input_cfg)
================================================
FILE: src/utils/basic_utils.py
================================================
import os
import ujson as json
import zipfile
import numpy as np
import pickle
import pandas as pd
def load_pickle(filename):
with open(filename, "rb") as f:
return pickle.load(f)
def save_pickle(data, filename):
with open(filename, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
def load_json(filename):
with open(filename, "r") as f:
return json.load(f)
def save_json(data, filename, save_pretty=False, sort_keys=False):
with open(filename, "w") as f:
if save_pretty:
f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
else:
json.dump(data, f)
def load_jsonl(filename):
with open(filename, "r") as f:
return [json.loads(l.strip("\n")) for l in f.readlines()]
def save_jsonl(data, filename):
"""data is a list"""
with open(filename, "w") as f:
f.write("\n".join([json.dumps(e) for e in data]))
def concat_json_list(filepaths, save_path):
json_lists = []
for p in filepaths:
json_lists += load_json(p)
save_json(json_lists, save_path)
def save_lines(list_of_str, filepath):
with open(filepath, "w") as f:
f.write("\n".join(list_of_str))
def read_lines(filepath):
with open(filepath, "r") as f:
return [e.strip("\n") for e in f.readlines()]
def mkdirp(p):
if not os.path.exists(p):
os.makedirs(p)
def flat_list_of_lists(l):
"""flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
return [item for sublist in l for item in sublist]
def convert_to_seconds(hms_time):
""" convert '00:01:12' to 72 seconds.
:hms_time (str): time in comma separated string, e.g. '00:01:12'
:return (int): time in seconds, e.g. 72
"""
times = [float(t) for t in hms_time.split(":")]
return times[0] * 3600 + times[1] * 60 + times[2]
def get_video_name_from_url(url):
return url.split("/")[-1][:-4]
def merge_dicts(list_dicts):
merged_dict = list_dicts[0].copy()
for i in range(1, len(list_dicts)):
merged_dict.update(list_dicts[i])
return merged_dict
def l2_normalize_np_array(np_array, eps=1e-5):
"""np_array: np.ndarray, (*, D), where the last dim will be normalized"""
return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps)
def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None,
exclude_dirs_substring=None):
"""make a zip file of root_dir, save it to save_path.
exclude_paths will be excluded if it is a subdir of root_dir.
An enclosing_dir is added is specified.
"""
abs_src = os.path.abspath(src_dir)
with zipfile.ZipFile(save_path, "w") as zf:
for dirname, subdirs, files in os.walk(src_dir):
if exclude_dirs is not None:
for e_p in exclude_dirs:
if e_p in subdirs:
subdirs.remove(e_p)
if exclude_dirs_substring is not None:
to_rm = []
for d in subdirs:
if exclude_dirs_substring in d:
to_rm.append(d)
for e in to_rm:
subdirs.remove(e)
arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:])
zf.write(dirname, arcname)
for filename in files:
if exclude_extensions is not None:
if os.path.splitext(filename)[1] in exclude_extensions:
continue # do not zip it
absname = os.path.join(dirname, filename)
arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:])
zf.write(absname, arcname)
class AverageMeter(object):
"""Computes and stores the average and current/max/min value"""
def __init__(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.max = -1e10
self.min = 1e10
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
self.max = -1e10
self.min = 1e10
def update(self, val, n=1):
self.max = max(val, self.max)
self.min = min(val, self.min)
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
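# Editor's note: illustrative usage (not part of the original module):
#   meter = AverageMeter()
#   meter.update(0.5)
#   meter.update(1.5)
#   meter.avg, meter.max, meter.min  # -> 1.0, 1.5, 0.5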
def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True):
"""Dissect an array (N, D) into a list a sub-array,
np_array.shape[0] == sum(lengths), Output is a list of nd arrays, singlton dimention is kept"""
if assert_equal:
assert len(np_array) == sum(lengths)
length_indices = [0, ]
for i in range(len(lengths)):
length_indices.append(length_indices[i] + lengths[i])
if dim == 0:
array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))]
elif dim == 1:
array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
elif dim == 2:
array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))]
else:
raise NotImplementedError
return array_list
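# Editor's note: illustrative usage (not part of the original module):
#   dissect_by_lengths(np.arange(10), [3, 3, 4])
#   # -> [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8, 9])]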
def get_ratio_from_counter(counter_obj, threshold=200):
keys = counter_obj.keys()
values = counter_obj.values()
filtered_values = [counter_obj[k] for k in keys if k > threshold]
return float(sum(filtered_values)) / sum(values)
def get_rounded_percentage(float_number, n_floats=2):
return round(float_number * 100, n_floats)
def read_dataframe(pkl_path):
return pd.read_pickle(pkl_path)
def save_frames_grid(img_array, out_path):
import torch
from torchvision.utils import make_grid
from PIL import Image
if len(img_array.shape) == 3:
img_array = img_array.unsqueeze(0)
elif len(img_array.shape) == 5:
b, t, c, h, w = img_array.shape
img_array = img_array.view(-1, c, h, w)
elif len(img_array.shape) == 4:
pass
else:
raise NotImplementedError('Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored.')
assert img_array.shape[1] == 3, "Exepcting input shape of (3, H, W), i.e. RGB-only."
grid = make_grid(img_array)
ndarr = grid.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
img = Image.fromarray(ndarr)
img.save(out_path)
================================================
FILE: src/utils/distributed.py
================================================
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
distributed API using Horovod
Modified from OpenNMT's native pytorch distributed utils
(https://github.com/OpenNMT/OpenNMT-py)
"""
import math
import pickle
import torch
from horovod import torch as hvd
from horovod.torch.mpi_ops import rank, size
def all_reduce_and_rescale_tensors(tensors, rescale_denom):
"""All-reduce and rescale tensors at once (as a flattened tensor)
Args:
tensors: list of Tensors to all-reduce
rescale_denom: denominator for rescaling summed Tensors
"""
# buffer size in bytes, determine equiv. # of elements based on data type
sz = sum(t.numel() for t in tensors)
buffer_t = tensors[0].new(sz).zero_()
# copy tensors into buffer_t
offset = 0
for t in tensors:
numel = t.numel()
buffer_t[offset:offset+numel].copy_(t.view(-1))
offset += numel
# all-reduce and rescale
hvd.allreduce_(buffer_t[:offset])
buffer_t.div_(rescale_denom)
# copy all-reduced buffer back into tensors
offset = 0
for t in tensors:
numel = t.numel()
t.view(-1).copy_(buffer_t[offset:offset+numel])
offset += numel
def all_reduce_and_rescale_tensors_chunked(tensors, rescale_denom,
buffer_size=10485760):
"""All-reduce and rescale tensors in chunks of the specified size.
Args:
tensors: list of Tensors to all-reduce
rescale_denom: denominator for rescaling summed Tensors
buffer_size: all-reduce chunk size in bytes
"""
# buffer size in bytes, determine equiv. # of elements based on data type
buffer_t = tensors[0].new(
math.ceil(buffer_size / tensors[0].element_size())).zero_()
buffer = []
def all_reduce_buffer():
# copy tensors into buffer_t
offset = 0
for t in buffer:
numel = t.numel()
buffer_t[offset:offset+numel].copy_(t.view(-1))
offset += numel
# all-reduce and rescale
hvd.allreduce_(buffer_t[:offset])
buffer_t.div_(rescale_denom)
# copy all-reduced buffer back into tensors
offset = 0
for t in buffer:
numel = t.numel()
t.view(-1).copy_(buffer_t[offset:offset+numel])
offset += numel
filled = 0
for t in tensors:
sz = t.numel() * t.element_size()
if sz > buffer_size:
# tensor is bigger than buffer, all-reduce and rescale directly
hvd.allreduce_(t)
t.div_(rescale_denom)
elif filled + sz > buffer_size:
# buffer is full, all-reduce and replace buffer with grad
all_reduce_buffer()
buffer = [t]
filled = sz
else:
# add tensor to buffer
buffer.append(t)
filled += sz
if len(buffer) > 0:
all_reduce_buffer()
def broadcast_tensors(tensors, root_rank, buffer_size=10485760):
"""broadcast tensors in chunks of the specified size.
Args:
tensors: list of Tensors to broadcast
root_rank: rank to broadcast
buffer_size: all-reduce chunk size in bytes
"""
# buffer size in bytes, determine equiv. # of elements based on data type
buffer_t = tensors[0].new(
math.ceil(buffer_size / tensors[0].element_size())).zero_()
buffer = []
def broadcast_buffer():
# copy tensors into buffer_t
offset = 0
for t in buffer:
numel = t.numel()
buffer_t[offset:offset+numel].copy_(t.view(-1))
offset += numel
# broadcast
hvd.broadcast_(buffer_t[:offset], root_rank)
# copy all-reduced buffer back into tensors
offset = 0
for t in buffer:
numel = t.numel()
t.view(-1).copy_(buffer_t[offset:offset+numel])
offset += numel
filled = 0
for t in tensors:
sz = t.numel() * t.element_size()
if sz > buffer_size:
# tensor is bigger than buffer, broadcast directly
hvd.broadcast_(t, root_rank)
elif filled + sz > buffer_size:
# buffer is full, broadcast and replace buffer with tensor
broadcast_buffer()
buffer = [t]
filled = sz
else:
# add tensor to buffer
buffer.append(t)
filled += sz
if len(buffer) > 0:
broadcast_buffer()
def all_gather_list(data, max_size=4096):
"""Gathers arbitrary data from all nodes into a list."""
world_size = hvd.size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != all_gather_list._in_buffer.size():
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
in_buffer = all_gather_list._in_buffer
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError(
'encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
# FIXME cannot create buffer
out = hvd.allgather(in_buffer.cuda())
results = []
for i in range(0, max_size*world_size, max_size):
out_buffer = out[i:i+max_size]
size = (255 * out_buffer[0].item()) + out_buffer[1].item()
bytes_list = bytes(out_buffer[2:size+2].tolist())
result = pickle.loads(bytes_list)
results.append(result)
return results
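# Editor's note: illustrative usage (not part of the original module); assumes
# hvd.init() has been called and each process has set its CUDA device:
#   per_rank_stats = {"rank": rank(), "n_ex": 128}
#   all_stats = all_gather_list(per_rank_stats)  # list with one entry per rank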
def any_broadcast(data, root_rank, max_size=4096):
"""broadcast arbitrary data from root_rank to all nodes."""
if not hasattr(any_broadcast, '_buffer') or \
max_size != any_broadcast._buffer.size():
any_broadcast._buffer = torch.cuda.ByteTensor(max_size)
buffer_ = any_broadcast._buffer
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError(
'encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
buffer_[0] = enc_size // 255 # this encoding works for max_size < 65k
buffer_[1] = enc_size % 255
buffer_[2:enc_size+2] = torch.ByteTensor(list(enc))
hvd.broadcast_(buffer_, root_rank)
size = (255 * buffer_[0].item()) + buffer_[1].item()
bytes_list = bytes(buffer_[2:size+2].tolist())
result = pickle.loads(bytes_list)
return result
def allgather_object(obj, name=None):
"""
Serializes and allgathers an object from all other processes.
Arguments:
obj: An object capable of being serialized without losing any context.
name: Optional name to use during allgather, will default to the class
type.
Returns:
The list of objects that were allgathered across all ranks.
"""
import io
import cloudpickle
if name is None:
name = type(obj).__name__
def load(byte_array):
buf = io.BytesIO(byte_array.tobytes())
return cloudpickle.load(buf)
b = io.BytesIO()
cloudpickle.dump(obj, b)
t = torch.ByteTensor(bytearray(b.getvalue()))
sz = torch.IntTensor([t.shape[0]])
sizes = hvd.allgather(sz, name=name + '.sz').numpy()
gathered = hvd.allgather(t, name=name + '.t').numpy()
def select(i):
start = sum(sizes[:i])
end = start + sizes[i]
return gathered[start:end]
return [load(select(i)) for i in range(size())]
================================================
FILE: src/utils/grad_ckpt.py
================================================
import torch
import warnings
def detach_variable(inputs):
if isinstance(inputs, tuple):
out = []
for inp in inputs:
x = inp.detach()
x.requires_grad = inp.requires_grad
out.append(x)
return tuple(out)
else:
raise RuntimeError(
"Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
def check_backward_validity(inputs):
if not any(inp.requires_grad for inp in inputs):
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
for i in range(len(ctx.input_tensors)):
temp = ctx.input_tensors[i]
ctx.input_tensors[i] = temp.detach()
ctx.input_tensors[i].requires_grad = temp.requires_grad
with torch.enable_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
return (None, None) + input_grads
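# Editor's note: illustrative sketch (not part of the original module) of how
# this checkpointing op is invoked; `layer` and `x` are hypothetical. The
# forward runs under no_grad and is recomputed inside backward():
#   layer = torch.nn.Linear(4, 4)
#   x = torch.randn(2, 4, requires_grad=True)
#   out = CheckpointFunction.apply(layer, 1, x, *layer.parameters())
#   out.sum().backward()  # x.grad and layer grads are populated here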
================================================
FILE: src/utils/load_save.py
================================================
"""
saving utilities
"""
import json
import os
from os.path import dirname, exists, join, realpath
import subprocess
from apex import amp
from easydict import EasyDict as edict
import torch
from src.utils.basic_utils import save_json, make_zipfile, load_json
from src.utils.logger import LOGGER
from typing import Any, Dict, Union
from src.modeling.timesformer.helpers import resize_spatial_embedding, resize_temporal_embedding
def save_training_meta(args):
# args is an EasyDict object, treat it the same as a normal dict
os.makedirs(join(args.output_dir, 'log'), exist_ok=True)
os.makedirs(join(args.output_dir, 'ckpt'), exist_ok=True)
# training args
save_args_path = join(args.output_dir, 'log', 'args.json')
save_json(args, save_args_path, save_pretty=True)
# model args
model_config = json.load(open(args.model_config))
save_model_config_path = join(args.output_dir, 'log', 'model_config.json')
save_json(model_config, save_model_config_path, save_pretty=True)
# save a copy of the codebase. !!!Do not store heavy files in your codebase when using it.
code_dir = dirname(dirname(dirname(os.path.realpath(__file__))))
code_zip_filename = os.path.join(args.output_dir, "code.zip")
LOGGER.info(f"Saving code from {code_dir} to {code_zip_filename}...")
make_zipfile(code_dir, code_zip_filename,
enclosing_dir="code",
exclude_dirs_substring="results",
exclude_dirs=["__pycache__", "output", "data", "ext"],
exclude_extensions=[".pyc", ".ipynb", ".swap", ".pt"])
LOGGER.info(f"Saving code done.")
class ModelSaver(object):
def __init__(self, output_dir):
self.output_dir = output_dir
self.max_save_load_trial = 10
def save(self, step, model, optimizer=None, prefix="model"):
model_path = join(self.output_dir, f"{prefix}_step_{step}.pt")
state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
for k, v in model.state_dict().items()}
# with retrial, as azure blob fails occasionally.
save_trial = 0
while save_trial < self.max_save_load_trial:
try:
LOGGER.info(f"ModelSaver save trial NO. {save_trial}")
torch.save(state_dict, model_path)
if optimizer is not None:
optimizer_state_dict = \
{k: v.cpu() if isinstance(v, torch.Tensor) else v
for k, v in optimizer.state_dict().items()}
dump = {'step': step, 'optimizer': optimizer_state_dict}
torch.save(
dump,
f'{self.output_dir}/{prefix}_step_{step}_train_state.pt')
break
except Exception as e:
save_trial += 1
def load_state_dict_with_pos_embed_resizing(model, loaded_state_dict_or_path,
num_patches, num_frames,
spatial_embed_key='visual_encoder.model.pos_embed',
temporal_embed_key='visual_encoder.model.time_embed',
strict=False,
remove_text_encoder_prefix=False
):
"""operated in-place, no need to return `model`,
Used to load e2e model checkpoints.
remove_text_encoder_prefix: set to True, when finetune downstream models from pre-trained checkpoints.
"""
if isinstance(loaded_state_dict_or_path, str):
loaded_state_dict = torch.load(
loaded_state_dict_or_path, map_location="cpu")
else:
loaded_state_dict = loaded_state_dict_or_path
new_state_dict = loaded_state_dict.copy()
for key in loaded_state_dict:
if 'text_encoder.bert' in key and remove_text_encoder_prefix:
new_key = key.replace('text_encoder.bert','text_encoder')
new_state_dict[new_key] = new_state_dict.pop(key)
loaded_state_dict = new_state_dict
## Resizing spatial embeddings in case they don't match
if num_patches + 1 != loaded_state_dict[spatial_embed_key].size(1):
loaded_state_dict[spatial_embed_key] = resize_spatial_embedding(loaded_state_dict, spatial_embed_key, num_patches)
else:
LOGGER.info('The length of spatial position embedding matches. No need to resize.')
## Resizing time embeddings in case they don't match
if temporal_embed_key in loaded_state_dict and num_frames != loaded_state_dict[temporal_embed_key].size(1):
loaded_state_dict[temporal_embed_key] = resize_temporal_embedding(loaded_state_dict, temporal_embed_key, num_frames)
else:
LOGGER.info('No temporal encoding found. Or the length of temporal position embedding matches. No need to resize.')
model_keys = set([k for k in list(model.state_dict().keys())])
load_keys = set(loaded_state_dict.keys())
toload = {}
mismatched_shape_keys = []
for k in model_keys:
if k in load_keys:
if model.state_dict()[k].shape != loaded_state_dict[k].shape:
mismatched_shape_keys.append(k)
else:
toload[k] = loaded_state_dict[k]
LOGGER.info("You can ignore the keys with `num_batches_tracked` or from task heads")
LOGGER.info("Keys in loaded but not in model:")
diff_keys = load_keys.difference(model_keys)
LOGGER.info(f"In total {len(diff_keys)}, {sorted(diff_keys)}")
LOGGER.info("Keys in model but not in loaded:")
diff_keys = model_keys.difference(load_keys)
LOGGER.info(f"In total {len(diff_keys)}, {sorted(diff_keys)}")
LOGGER.info("Keys in model and loaded, but shape mismatched:")
LOGGER.info(f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}")
model.load_state_dict(toload, strict=strict)
def compare_dict_difference(dict1, dict2, dict1_name="dict1",
dict2_name="dict2",
print_value_diff=True, verbose=False):
"""
Args:
dict1:
dict2:
dict1_name:
dict2_name:
print_value_diff: bool, output dict value difference within shared keys
for dict1 and dict2. In effect only when verbose == True
verbose:
"""
keys1 = set(dict1.keys())
keys2 = set(dict2.keys())
shared_keys = keys1.intersection(keys2)
keys1_unique = keys1.difference(shared_keys)
keys2_unique = keys2.difference(shared_keys)
key_diff_list = list(keys1_unique) + list(keys2_unique)
# value difference in the shared keys in dict1 and dict2
value_diff_dict = {}
for k in shared_keys:
if dict1[k] != dict2[k]:
value_diff_dict[k] = [(dict1_name, dict1[k]), (dict2_name, dict2[k])]
if verbose:
LOGGER.info("=" * 30 + "key difference")
LOGGER.info(f"keys in {dict1_name} but not in {dict2_name}: "
f"total {len(keys1_unique)}, {sorted(keys1_unique)}")
LOGGER.info(f"keys in {dict2_name} but not in {dict1_name}: "
f"total {len(keys2_unique)}, {sorted(keys2_unique)}")
if verbose and print_value_diff:
LOGGER.info("=" * 30 + "value difference")
LOGGER.info(f"{json.dumps(value_diff_dict, indent=4)}")
return value_diff_dict, key_diff_list
def _to_cuda(state):
""" usually load from cpu checkpoint but need to load to cuda """
if isinstance(state, torch.Tensor):
ret = state.cuda() # assumes device properly set by torch.cuda.set_device
if 'Half' in state.type():
ret = ret.float() # apex O2 requires it
return ret
elif isinstance(state, list):
new_state = [_to_cuda(t) for t in state]
elif isinstance(state, tuple):
new_state = tuple(_to_cuda(t) for t in state)
elif isinstance(state, dict):
new_state = {n: _to_cuda(t) for n, t in state.items()}
else:
return state
return new_state
def _to_cpu(state):
""" store in cpu to avoid GPU0 device, fp16 to save space """
if isinstance(state, torch.Tensor):
ret = state.cpu()
if 'Float' in state.type():
ret = ret.half()
return ret
elif isinstance(state, list):
new_state = [_to_cpu(t) for t in state]
elif isinstance(state, tuple):
new_state = tuple(_to_cpu(t) for t in state)
elif isinstance(state, dict):
new_state = {n: _to_cpu(t) for n, t in state.items()}
else:
return state
return new_state
class TrainingRestorer(object):
"""ckpt_dict: a dict contains all optimizers/models"""
def __init__(self, opts, **ckpt_dict):
if exists(opts.output_dir):
restore_opts = json.load(open(
f'{opts.output_dir}/log/args.json', 'r'))
assert opts == edict(restore_opts)
# keep 2 checkpoints in case of corrupted
self.save_path = f'{opts.output_dir}/restore.pt'
self.backup_path = f'{opts.output_dir}/restore_backup.pt'
self.ckpt_dict = ckpt_dict
self.save_steps = opts.save_steps
self.amp = opts.fp16
# since saving to or loading from azure blob fails sometimes
self.max_save_load_trial = 10
if exists(self.save_path) or exists(self.backup_path):
LOGGER.info('found previous checkpoint. try to resume...')
# with retrial, as azure blob fails occasionally.
restore_trial = 0
while restore_trial < self.max_save_load_trial:
LOGGER.info(f"TrainingRestorer restore trial NO. {restore_trial}")
try:
self.restore()
break
except Exception as e:
restore_trial += 1
else:
self.global_step = 0
def step(self):
self.global_step += 1
if self.global_step % self.save_steps == 0:
# with retrial, as azure blob fails occasionally.
save_trial = 0
while save_trial < self.max_save_load_trial:
LOGGER.info(f"TrainingRestorer save trial NO. {save_trial}")
try:
self.save()
break
except Exception as e:
save_trial += 1
def save(self):
checkpoint_to_save = {'global_step': self.global_step}
for k in self.ckpt_dict:
checkpoint_to_save[k] = _to_cpu(self.ckpt_dict[k].state_dict())
if self.amp:
checkpoint_to_save['amp_state_dict'] = amp.state_dict()
if exists(self.save_path):
os.rename(self.save_path, self.backup_path)
torch.save(checkpoint_to_save, self.save_path)
def restore(self):
try:
checkpoint = torch.load(self.save_path)
except Exception:
checkpoint = torch.load(self.backup_path)
self.global_step = checkpoint['global_step']
for k in self.ckpt_dict:
self.ckpt_dict[k].load_state_dict(_to_cuda(checkpoint[k]))
if self.amp:
amp.load_state_dict(checkpoint['amp_state_dict'])
LOGGER.info(f'resume training from step {self.global_step}')
class E2E_TrainingRestorer(object):
def __init__(self, opts, model, optimizer):
if exists(f"{opts.output_dir}/log/args.json"):
restore_opts = json.load(
open(f'{opts.output_dir}/log/args.json', 'r'))
with open(join(
opts.output_dir, 'log',
'restore_args.json'), 'w') as writer:
json.dump(vars(opts), writer, indent=4)
# assert opts == edict(restore_opts)
# keep 2 checkpoints in case of corrupted
self.save_path = f'{opts.output_dir}/restore.pt'
self.backup_path = f'{opts.output_dir}/restore_backup.pt'
self.model = model
self.optimizer = optimizer
self.save_steps = int(opts.save_steps_ratio * opts.num_train_steps)
self.amp = opts.fp16
# since saving to or loading from azure blob fails sometimes
self.max_save_load_trial = 10
if exists(self.save_path) or exists(self.backup_path):
LOGGER.info('found previous checkpoint. try to resume...')
# with retrial, as azure blob fails occasionally.
restore_trial = 0
while restore_trial < self.max_save_load_trial:
LOGGER.info(f"TrainingRestorer restore trial NO. {restore_trial}")
try:
self.restore(opts)
break
except Exception as e:
restore_trial += 1
else:
self.global_step = 0
def step(self):
self.global_step += 1
if self.global_step % self.save_steps == 0:
# with retrial, as azure blob fails occasionally.
save_trial = 0
while save_trial < self.max_save_load_trial:
LOGGER.info(f"TrainingRestorer save trial NO. {save_trial}")
try:
self.save()
break
except Exception as e:
save_trial += 1
def save(self):
checkpoint = {'global_step': self.global_step,
'model_state_dict': _to_cpu(self.model.state_dict()),
'optim_state_dict': _to_cpu(self.optimizer.state_dict())}
if self.amp:
checkpoint['amp_state_dict'] = amp.state_dict()
if exists(self.save_path):
os.rename(self.save_path, self.backup_path)
torch.save(checkpoint, self.save_path)
def restore(self, opts):
try:
checkpoint = torch.load(self.save_path)
except Exception:
checkpoint = torch.load(self.backup_path)
self.global_step = checkpoint['global_step']
self.model.load_state_dict(_to_cuda(checkpoint['model_state_dict']))
self.optimizer.load_state_dict(
_to_cuda(checkpoint['optim_state_dict']))
if self.amp:
amp.load_state_dict(checkpoint['amp_state_dict'])
LOGGER.info(f'resume training from step {self.global_step}')
================================================
FILE: src/utils/logger.py
================================================
"""
references: UNITER
"""
import logging
from tensorboardX import SummaryWriter
_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_DATE_FMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
LOGGER = logging.getLogger('__main__') # this is the global logger
def add_log_to_file(log_path):
fh = logging.FileHandler(log_path)
formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
fh.setFormatter(formatter)
LOGGER.addHandler(fh)
class TensorboardLogger(object):
def __init__(self):
self._logger = None
self._global_step = 0
def create(self, path):
self._logger = SummaryWriter(path)
def noop(self, *args, **kwargs):
return
def step(self):
self._global_step += 1
@property
def global_step(self):
return self._global_step
@global_step.setter
def global_step(self, step):
self._global_step = step
def log_scalar_dict(self, log_dict, prefix=''):
""" log a dictionary of scalar values"""
if self._logger is None:
return
if prefix:
prefix = f'{prefix}_'
for name, value in log_dict.items():
if isinstance(value, dict):
self.log_scalar_dict(value, prefix=f'{prefix}{name}')
else:
self._logger.add_scalar(f'{prefix}{name}', value,
self._global_step)
def __getattr__(self, name):
if self._logger is None:
return self.noop
return self._logger.__getattribute__(name)
TB_LOGGER = TensorboardLogger()
class RunningMeter(object):
""" running meteor of a scalar value
(useful for monitoring training loss)
"""
def __init__(self, name, val=None, smooth=0.99):
self._name = name
self._sm = smooth
self._val = val
def __call__(self, value):
self._val = (value if self._val is None
else value*(1-self._sm) + self._val*self._sm)
def __str__(self):
return f'{self._name}: {self._val:.4f}'
@property
def val(self):
return self._val
@property
def name(self):
return self._name
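# Editor's note: illustrative usage (not part of the original module); with the
# default smoothing of 0.99 the meter tracks an exponential moving average:
#   loss_meter = RunningMeter('train_loss')
#   loss_meter(2.0)   # first value is taken as-is
#   loss_meter(1.0)   # 1.0 * 0.01 + 2.0 * 0.99
#   loss_meter.val    # -> 1.99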
================================================
FILE: src/utils/misc.py
================================================
"""
modified from UNITER
"""
import json
import random
import sys
import torch
import numpy as np
class NoOp(object):
""" useful for distributed training No-Ops """
def __getattr__(self, name):
return self.noop
def noop(self, *args, **kwargs):
return
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def zero_none_grad(model):
for p in model.parameters():
if p.grad is None and p.requires_grad:
p.grad = p.data.new(p.size()).zero_()
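# Editor's note (descriptive, not part of the original module): filling None
# grads with zeros keeps every trainable parameter in the Horovod allreduce;
# the training scripts call zero_none_grad(model) right before
# optimizer.synchronize() (see src/tasks/run_video_retrieval.py).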