Repository: kdexd/virtex
Branch: master
Commit: ae67b23f86ab
Files: 99
Total size: 274.8 KB

Directory structure:
gitextract_cto194sv/

├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── configs/
│   ├── _base_bicaptioning_R_50_L1_H1024.yaml
│   ├── backbone_ablations/
│   │   ├── bicaptioning_R_101_L1_H1024.yaml
│   │   ├── bicaptioning_R_50W2X_L1_H1024.yaml
│   │   └── bicaptioning_R_50_L1_H1024.yaml
│   ├── depth_ablations/
│   │   ├── bicaptioning_R_50_L1_H1024.yaml
│   │   ├── bicaptioning_R_50_L2_H1024.yaml
│   │   ├── bicaptioning_R_50_L3_H1024.yaml
│   │   └── bicaptioning_R_50_L4_H1024.yaml
│   ├── detectron2/
│   │   ├── _base_faster_rcnn_R_50_C4_BN.yaml
│   │   ├── _base_mask_rcnn_R_50_FPN.yaml
│   │   ├── coco_segm_default_init_2x.yaml
│   │   ├── lvis_segm_default_init_2x.yaml
│   │   ├── lvis_segm_imagenet_init_2x.yaml
│   │   └── voc_det_default_init_24k.yaml
│   ├── downstream/
│   │   ├── imagenet_clf.yaml
│   │   ├── inaturalist_clf.yaml
│   │   └── voc07_clf.yaml
│   ├── task_ablations/
│   │   ├── bicaptioning_R_50_L1_H2048.yaml
│   │   ├── captioning_R_50_L1_H2048.yaml
│   │   ├── masked_lm_R_50_L1_H2048.yaml
│   │   ├── multilabel_classification_R_50.yaml
│   │   └── token_classification_R_50.yaml
│   └── width_ablations/
│       ├── bicaptioning_R_50_L1_H1024.yaml
│       ├── bicaptioning_R_50_L1_H2048.yaml
│       ├── bicaptioning_R_50_L1_H512.yaml
│       └── bicaptioning_R_50_L1_H768.yaml
├── docs/
│   ├── Makefile
│   ├── _templates/
│   │   └── layout.html
│   ├── conf.py
│   ├── index.rst
│   └── virtex/
│       ├── config.rst
│       ├── data.datasets.rst
│       ├── data.rst
│       ├── data.tokenizers.rst
│       ├── data.transforms.rst
│       ├── factories.rst
│       ├── model_zoo.rst
│       ├── models.rst
│       ├── modules.embedding.rst
│       ├── modules.rst
│       ├── modules.textual_heads.rst
│       ├── modules.visual_backbones.rst
│       ├── optim.lookahead.rst
│       ├── optim.lr_scheduler.rst
│       ├── optim.rst
│       ├── usage/
│       │   ├── downstream.rst
│       │   ├── model_zoo.rst
│       │   ├── pretrain.rst
│       │   └── setup_dependencies.rst
│       ├── utils.beam_search.rst
│       ├── utils.checkpointing.rst
│       ├── utils.common.rst
│       ├── utils.distributed.rst
│       ├── utils.metrics.rst
│       ├── utils.rst
│       └── utils.timer.rst
├── hubconf.py
├── requirements.txt
├── scripts/
│   ├── build_vocabulary.py
│   ├── clf_linear.py
│   ├── clf_voc07.py
│   ├── eval_captioning.py
│   ├── eval_detectron2.py
│   └── pretrain_virtex.py
├── setup.py
└── virtex/
    ├── __init__.py
    ├── config.py
    ├── data/
    │   ├── __init__.py
    │   ├── datasets/
    │   │   ├── captioning.py
    │   │   ├── classification.py
    │   │   ├── coco_captions.py
    │   │   ├── downstream.py
    │   │   └── masked_lm.py
    │   ├── tokenizers.py
    │   └── transforms.py
    ├── factories.py
    ├── model_zoo/
    │   ├── __init__.py
    │   └── model_zoo.py
    ├── models/
    │   ├── __init__.py
    │   ├── captioning.py
    │   ├── classification.py
    │   └── masked_lm.py
    ├── modules/
    │   ├── embedding.py
    │   ├── textual_heads.py
    │   └── visual_backbones.py
    ├── optim/
    │   ├── __init__.py
    │   ├── lookahead.py
    │   └── lr_scheduler.py
    └── utils/
        ├── beam_search.py
        ├── checkpointing.py
        ├── common.py
        ├── distributed.py
        ├── metrics.py
        ├── nucleus_sampling.py
        └── timer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Code Editors
.vscode
.idea

# Code linters
.mypy_cache

# Datasets and preprocessed files
data/
!virtex/data

# IPython Notebook
.ipynb_checkpoints

# virtualenv
venv/
ENV/

# Temporary scripts to (smoke) test out bits and pieces of code.
scripts/test_*

# Data (symlinks) directory, model checkpoints, tensorboard logs etc.
datasets/
checkpoints/
virtex/utils/assets/
!virtex/data/datasets/
virtex/model_zoo/configs


================================================
FILE: CHANGELOG.md
================================================
CHANGELOG
=========

This CHANGELOG file records changes between different arXiv versions of our paper, and the version of this codebase which should be used to reproduce the results in the corresponding arXiv version. View changes between code versions on the [Releases page](https://github.com/kdexd/virtex/releases).

ArXiv v2 -> v3
==============

**Code version:** `v1.2`.

Fix image captioning results with a modified beam search implementation. _Rest of the downstream task results and pre-trained models are unchanged._


ArXiv v1 -> v2
==============

**Code version:** `v1.0` or `v1.1`.

[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).

While the core motivation and approach are the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performance across the board (by small decimal margins). New models are available in the [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html); however, links to old models in `v0.9` will remain active until June 30, 2021. We encourage you to use the new models!

We have updated the experiment config files for all changes described below.

Experiment Changes
------------------

### New Feature:

Add a new pretraining task for BERT-style _Masked Language Modeling_. The pre-trained model is released in the Model Zoo.
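BERT-style masked language modeling corrupts a fraction of caption tokens and trains the model to recover the originals. The sketch below follows the standard BERT recipe (select about 15% of tokens; replace 80% of those with `[MASK]`, 10% with a random token, and keep 10% unchanged). It is an illustration, not the VirTex implementation: the 15% rate, the 80/10/10 split, the hypothetical `IGNORE_INDEX` label value and the choice never to mask `SOS`/`EOS` are all assumptions. Only the special-token indices and the vocabulary size are taken from the base pretraining config.

```python
import random
from typing import List, Tuple

# Special-token indices and vocabulary size from the base pretraining config.
UNK_INDEX, SOS_INDEX, EOS_INDEX, MASK_INDEX = 0, 1, 2, 3
VOCAB_SIZE = 10000
# Assumption: hypothetical label value meaning "no loss at this position".
IGNORE_INDEX = -100


def mask_caption(
    tokens: List[int], rng: random.Random, mask_prob: float = 0.15
) -> Tuple[List[int], List[int]]:
    """Return (masked_tokens, labels) using the standard BERT 80/10/10 rule."""
    masked = list(tokens)
    labels = [IGNORE_INDEX] * len(tokens)
    # Assumption: boundary tokens are never masked.
    candidates = [i for i, t in enumerate(tokens) if t not in (SOS_INDEX, EOS_INDEX)]
    # Mask at least one token whenever there is something to mask.
    num_to_mask = max(1, round(mask_prob * len(candidates))) if candidates else 0
    for i in rng.sample(candidates, num_to_mask):
        labels[i] = tokens[i]  # loss is computed only where a label is set
        draw = rng.random()
        if draw < 0.8:
            masked[i] = MASK_INDEX  # 80%: replace with [MASK]
        elif draw < 0.9:
            # 10%: random token, skipping the four special indices (assumption).
            masked[i] = rng.randrange(MASK_INDEX + 1, VOCAB_SIZE)
        # Remaining 10%: keep the original token as input.
    return masked, labels
```

For a six-token caption such as `[1, 57, 902, 14, 388, 2]`, four tokens are candidates and `round(0.15 * 4) = 1`, so exactly one word token receives a label.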

### Pre-training:

- The only change during pre-training is that we do not apply weight decay to LayerNorm parameters and biases in the input embedding and transformer layers. We still apply weight decay to the biases of the output linear layer (before softmax).

- Other factors that could affect results:
  - Use the official [albumentations.ColorJitter transform](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter), which mimics the torchvision ColorJitter transform. Earlier, I implemented [my own ColorJitter](https://github.com/kdexd/virtex/blob/c19e7fc9b98e98af82286ed1537b6f588eaeac44/virtex/data/transforms.py#L156) because albumentations did not have one.
  - Use PyTorch Native AMP (Automatic Mixed Precision) instead of NVIDIA Apex.
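The weight-decay change above is expressed in the configs as a regular expression over parameter names (`OPTIM.NO_DECAY` in `configs/_base_bicaptioning_R_50_L1_H1024.yaml`). Below is a minimal sketch of how such a pattern could split parameters into two optimizer groups. The parameter names are hypothetical, and matching with `re.fullmatch` is an assumption about how the pattern is applied. In a real training script the lists would hold tensors from `model.named_parameters()` and be passed to an optimizer such as `torch.optim.SGD`, which accepts a separate `weight_decay` per parameter group.

```python
import re
from typing import Dict, List

# Pattern copied from OPTIM.NO_DECAY in the base pretraining config.
NO_DECAY = r".*textual.(embedding|transformer).*(norm.*|bias)"


def split_weight_decay(
    param_names: List[str], weight_decay: float, pattern: str = NO_DECAY
) -> List[Dict[str, object]]:
    """Group parameter names into decay / no-decay optimizer param groups."""
    decay, no_decay = [], []
    for name in param_names:
        # Assumption: the whole name must match the pattern (re.fullmatch).
        (no_decay if re.fullmatch(pattern, name) else decay).append(name)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

With hypothetical names, `textual.embedding.layer_norm.weight` and `textual.transformer.layers.0.linear1.bias` land in the no-decay group, while `textual.output.bias` keeps weight decay, in line with the description above.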

### Downstream Evaluations:

1. **PASCAL VOC 2007 Linear Classification:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-b4405dd4879a48ef1e5b1e2801035909584a5f1f32f63d5e793fb50dee077b97)
   - Instead of training linear SVMs on 8192-dimensional average pooled features from ResNet-50 (7x7x2048 -> 2x2x2048), as in [(Misra et al. 2019)](https://arxiv.org/abs/1905.01235), we directly train SVMs on 2048-dimensional global average pooled features, following recent works like [SwAV (Caron et al. 2020)](https://arxiv.org/abs/2006.09882).
   - We change the pre-processing: resize the shortest edge to 256 pixels, and take a center crop of 224 pixels.
   - These changes improve VOC mAP by 1-2 points everywhere and make SVM training faster. Since we select the best checkpoint based on this metric, all results on other downstream tasks also change in `ArXiv v2` (but the trends remain the same).

2. **ImageNet Linear Evaluation:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-d3dea1e7bf97d0cfca4b59a47c0a9bb81e78b8827654fe0258df9ce2c3f5f41c)
   - Changed random resized crop scale from (20-100%) to (8-100%) for consistency with evaluations in SSL works like MoCo and SwAV.
   - Use cosine LR decay instead of step decay, following SwAV. Improves accuracy by up to 1%.

3. **iNaturalist Fine-tuning:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-09096da78cfcde3a604ce22d80313f0800225d928cce5ef7334b89a382adfe4d)
   - This evaluation is left unchanged across ArXiv versions, but we fixed a typo in the image pre-processing step present in the publicly released config.

4. **Detectron2 tasks (COCO and LVIS Instance Segmentation, VOC Detection):**
   - Heavily simplified the script. Updated Detectron2 uses a more memory-efficient SyncBatchNorm and supports AMP.
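The feature sizes in the VOC07 item come from pooling the final 7x7x2048 ResNet-50 feature map in two different ways. The NumPy sketch below (random numbers stand in for real activations) shows that pooling to a 2x2 grid yields 2x2x2048 = 8192 values, while global average pooling yields 2048. The `adaptive_avg_pool2d` helper is a hypothetical re-implementation written for this example; it uses floor/ceil bin boundaries in the style of PyTorch's adaptive average pooling.

```python
import numpy as np


def adaptive_avg_pool2d(x: np.ndarray, out_size: int) -> np.ndarray:
    """Average-pool a (C, H, W) array to (C, out_size, out_size).

    Bin i spans [floor(i * H / out), ceil((i + 1) * H / out)), so bins can
    overlap when H is not divisible by out_size (7 -> 2 uses rows 0-3 and 3-6).
    """
    c, h, w = x.shape
    out = np.empty((c, out_size, out_size), dtype=x.dtype)
    for i in range(out_size):
        h0, h1 = (i * h) // out_size, -(-((i + 1) * h) // out_size)
        for j in range(out_size):
            w0, w1 = (j * w) // out_size, -(-((j + 1) * w) // out_size)
            out[:, i, j] = x[:, h0:h1, w0:w1].mean(axis=(1, 2))
    return out


# Stand-in for the last ResNet-50 feature map of one 224x224 image.
features = np.random.default_rng(0).standard_normal((2048, 7, 7)).astype(np.float32)

old_style = adaptive_avg_pool2d(features, 2).reshape(-1)  # 2 * 2 * 2048 = 8192 dims
new_style = features.mean(axis=(1, 2))                    # global average pool: 2048 dims
```

Global average pooling is the `out_size=1` special case of the same operation, which is why the newer protocol gives a 4x smaller SVM input.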



================================================
FILE: LICENSE
================================================
Copyright (c) 2020, Karan Desai.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute,
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


================================================
FILE: README.md
================================================
VirTex: Learning Visual Representations from Textual Annotations
================================================================

<h4>
Karan Desai and Justin Johnson
</br>
<span style="font-size: 14pt; color: #555555">
University of Michigan
</span>
</h4>
<hr>

**CVPR 2021** [arxiv.org/abs/2006.06666][1]

**Model Zoo, Usage Instructions and API docs:** [kdexd.github.io/virtex](https://kdexd.github.io/virtex)

VirTex is a pretraining approach which uses semantically dense captions to
learn visual representations. We train CNN + Transformers from scratch on
COCO Captions, and transfer the CNN to downstream vision tasks including
image classification, object detection, and instance segmentation.
VirTex matches or outperforms models which use ImageNet for pretraining -- 
both supervised and unsupervised -- despite using up to 10x fewer images.

![virtex-model](docs/_static/system_figure.jpg)


Get the pretrained ResNet-50 visual backbone from our best performing VirTex
model in one line *without any installation*!

```python
import torch

# That's it, this one line only requires PyTorch.
model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)
```

### Note (For returning users before January 2021):

The pretrained models in our model zoo have changed from [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0) onwards.
They are slightly better tuned than older models, and reproduce the results in our
CVPR 2021 accepted paper ([arXiv v2](https://arxiv.org/abs/2006.06666v2)). 
Some training and evaluation hyperparams are changed since [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9).
Please refer to [`CHANGELOG.md`](https://github.com/kdexd/virtex/blob/master/CHANGELOG.md) for details.


Usage Instructions
------------------

1. [How to setup this codebase?][2]  
2. [VirTex Model Zoo][3]  
3. [How to train your VirTex model?][4]  
4. [How to evaluate on downstream tasks?][5]  

Full documentation is available at [kdexd.github.io/virtex](https://kdexd.github.io/virtex).


Citation
--------

If you find this code useful, please consider citing:

```text
@inproceedings{desai2021virtex,
    title={{VirTex: Learning Visual Representations from Textual Annotations}},
    author={Karan Desai and Justin Johnson},
    booktitle={CVPR},
    year={2021}
}
```

Acknowledgments
---------------

We thank Harsh Agrawal, Mohamed El Banani, Richard Higgins, Nilesh Kulkarni
and Chris Rockwell for helpful discussions and feedback on the paper. We thank
Ishan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for
discussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and
Yuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting
the Instance Segmentation pretraining task ablation; and Stefan Lee for
suggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs
during project development; and UMich ARC-TS team for support with GPU cluster
management. Finally, we thank all the Starbucks outlets in Ann Arbor for many
hours of free WiFi. This work was partially supported by the Toyota Research
Institute (TRI). However, note that this article solely reflects the opinions
and conclusions of its authors and not TRI or any other Toyota entity.


[1]: https://arxiv.org/abs/2006.06666
[2]: https://kdexd.github.io/virtex/virtex/usage/setup_dependencies.html
[3]: https://kdexd.github.io/virtex/virtex/usage/model_zoo.html
[4]: https://kdexd.github.io/virtex/virtex/usage/pretrain.html
[5]: https://kdexd.github.io/virtex/virtex/usage/downstream.html


================================================
FILE: configs/_base_bicaptioning_R_50_L1_H1024.yaml
================================================
# -----------------------------------------------------------------------------
# Base config: VirTex pretraining for our "base" bicaptioning model:
# ResNet-50 + (L = 1, H = 1024) transformer trained for 500K iterations.
# -----------------------------------------------------------------------------
RANDOM_SEED: 0
AMP: true
CUDNN_BENCHMARK: true
CUDNN_DETERMINISTIC: false

DATA:
  ROOT: "datasets/coco"
  TOKENIZER_MODEL: "datasets/vocab/coco_10k.model"
  VOCAB_SIZE: 10000
  UNK_INDEX: 0
  SOS_INDEX: 1
  EOS_INDEX: 2
  MASK_INDEX: 3

  IMAGE_CROP_SIZE: 224
  MAX_CAPTION_LENGTH: 30

  IMAGE_TRANSFORM_TRAIN:
    - "random_resized_crop"
    - "horizontal_flip"
    - "color_jitter"
    - "normalize"

  IMAGE_TRANSFORM_VAL:
    - "smallest_resize"
    - "center_crop"
    - "normalize"

MODEL:
  NAME: "virtex"

  VISUAL:
    NAME: "torchvision::resnet50"
    PRETRAINED: false
    FROZEN: false

  TEXTUAL:
    NAME: "transdec_postnorm::L1_H1024_A16_F4096"
    DROPOUT: 0.1

  DECODER:
    NAME: "beam_search"
    BEAM_SIZE: 5

OPTIM:
  OPTIMIZER_NAME: "sgd"
  SGD_MOMENTUM: 0.9
  WEIGHT_DECAY: 0.0001

  LOOKAHEAD:
    USE: true
    ALPHA: 0.5
    STEPS: 5

  BATCH_SIZE: 256
  CNN_LR: 0.2
  LR: 0.001
  NUM_ITERATIONS: 500000

  WARMUP_STEPS: 10000
  LR_DECAY_NAME: "cosine"

  NO_DECAY: ".*textual.(embedding|transformer).*(norm.*|bias)"
  CLIP_GRAD_NORM: 10.0



================================================
FILE: configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  VISUAL:
    NAME: "torchvision::resnet101"


================================================
FILE: configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  VISUAL:
    NAME: "torchvision::wide_resnet50_2"


================================================
FILE: configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"


================================================
FILE: configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"


================================================
FILE: configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  TEXTUAL:
    NAME: "transdec_postnorm::L2_H1024_A16_F4096"


================================================
FILE: configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  TEXTUAL:
    NAME: "transdec_postnorm::L3_H1024_A16_F4096"


================================================
FILE: configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  TEXTUAL:
    NAME: "transdec_postnorm::L4_H1024_A16_F4096"


================================================
FILE: configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml
================================================
# ----------------------------------------------------------------------------
# Train a Faster R-CNN with ResNet-50 and C4 backbone. This config follows
# the Detectron2 format and is unrelated to our VirTex configs. Params here
# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722).
# ----------------------------------------------------------------------------

INPUT:
  # Input format will always be RGB, consistent with torchvision.
  FORMAT: "RGB"
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"

  # Train all layers end-to-end by default.
  BACKBONE:
    NAME: build_resnet_backbone
    FREEZE_AT: 0

  # Fine-tune with SyncBN.
  # STRIDE_IN_1X1 is False for torchvision-like models.
  RESNETS:
    DEPTH: 50
    NORM: SyncBN
    STRIDE_IN_1X1: False

  RPN:
    PRE_NMS_TOPK_TEST: 6000
    POST_NMS_TOPK_TEST: 1000

  # ROI head with extra BN layer after res5 stage.
  ROI_HEADS:
    NAME: "Res5ROIHeadsExtraNorm"

  # ImageNet color mean for torchvision-like models (RGB order).
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]

SOLVER:
  # This is for 8 GPUs, apply linear scaling for 4 GPUs.
  IMS_PER_BATCH: 16
  BASE_LR: 0.02

TEST:
  PRECISE_BN:
    ENABLED: True

VERSION: 2


================================================
FILE: configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml
================================================
# ----------------------------------------------------------------------------
# Train a Mask R-CNN with ResNet-50 and FPN backbone. This config follows
# the Detectron2 format and is unrelated to our VirTex configs. Params here
# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722).
# ----------------------------------------------------------------------------

INPUT:
  # Input format will always be RGB, consistent with torchvision.
  FORMAT: "RGB"
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800

MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"

  # Train all layers end-to-end by default.
  BACKBONE:
    NAME: "build_resnet_fpn_backbone"
    FREEZE_AT: 0

  # Fine-tune with SyncBN.
  # STRIDE_IN_1X1 is False for torchvision-like models.
  RESNETS:
    DEPTH: 50
    NORM: "SyncBN"
    STRIDE_IN_1X1: False
    OUT_FEATURES: ["res2", "res3", "res4", "res5"]

  FPN:
    IN_FEATURES: ["res2", "res3", "res4", "res5"]

  ANCHOR_GENERATOR:
    # One anchor size for each input feature map
    SIZES: [[32], [64], [128], [256], [512]]
    # Three aspect ratios (shared across all input feature maps)
    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]

  RPN:
    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
    PRE_NMS_TOPK_TRAIN: 2000
    PRE_NMS_TOPK_TEST: 1000

    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000

  ROI_HEADS:
    NAME: "StandardROIHeads"
    IN_FEATURES: ["p2", "p3", "p4", "p5"]

  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7

  ROI_MASK_HEAD:
    NAME: "MaskRCNNConvUpsampleHead"
    NUM_CONV: 4
    POOLER_RESOLUTION: 14

  # ImageNet color mean for torchvision-like models (RGB order).
  # These are in the [0-255] range, as expected by Detectron2. The rest of our
  # codebase uses the [0-1] range; both conventions are equivalent.
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]

SOLVER:
  # This is for 8 GPUs, apply linear scaling for 4 GPUs.
  IMS_PER_BATCH: 16
  BASE_LR: 0.02

TEST:
  PRECISE_BN:
    ENABLED: True

VERSION: 2


================================================
FILE: configs/detectron2/coco_segm_default_init_2x.yaml
================================================
# -----------------------------------------------------------------------------
# Train a Mask R-CNN R50-FPN backbone on COCO instance segmentation with any of
# these weight init: random, imagenet (torchvision), virtex or MoCo.
# -----------------------------------------------------------------------------
_BASE_: "_base_mask_rcnn_R_50_FPN.yaml"

DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)

MODEL:
  MASK_ON: True
  # FPN also has SyncBN, as opposed to no norm (usually).
  FPN:
    NORM: "SyncBN"
  
  # This will be ignored, weights will be loaded manually in the script.
  WEIGHTS: ""
  
SOLVER:
  STEPS: (120000, 160000)
  MAX_ITER: 180000
  
VERSION: 2


================================================
FILE: configs/detectron2/lvis_segm_default_init_2x.yaml
================================================
# -----------------------------------------------------------------------------
# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation with any of
# these weight init: random, virtex or MoCo. (ImageNet init config is separate)
# -----------------------------------------------------------------------------
_BASE_: "_base_mask_rcnn_R_50_FPN.yaml"

DATASETS:
  TRAIN: ("lvis_v1_train",)
  TEST: ("lvis_v1_val",)

DATALOADER:
  SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
  REPEAT_THRESHOLD: 0.001

TEST:
  DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300.

MODEL:
  MASK_ON: True
  # FPN also has SyncBN, as opposed to no norm (usually).
  FPN:
    NORM: "SyncBN"

  ROI_HEADS:
    NUM_CLASSES: 1203
    SCORE_THRESH_TEST: 0.0001

  # This will be ignored, weights will be loaded manually in the script.
  WEIGHTS: ""

SOLVER:
  STEPS: (120000, 160000)
  MAX_ITER: 180000

VERSION: 2



================================================
FILE: configs/detectron2/lvis_segm_imagenet_init_2x.yaml
================================================
# -----------------------------------------------------------------------------
# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation
# with weights initialized from supervised ImageNet pretraining (torchvision).
# Key difference is that fine-tuning here happens with BN frozen.
# -----------------------------------------------------------------------------
_BASE_: "_base_mask_rcnn_R_50_FPN.yaml"

DATASETS:
  TRAIN: ("lvis_v1_train",)
  TEST: ("lvis_v1_val",)

DATALOADER:
  SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
  REPEAT_THRESHOLD: 0.001

TEST:
  DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300.

MODEL:
  MASK_ON: True
  # Do not fine-tune with SyncBN for ImageNet init on LVIS; keep BN frozen.
  RESNETS:
    NORM: "FrozenBN"

  ROI_HEADS:
    NUM_CLASSES: 1203
    SCORE_THRESH_TEST: 0.0001

  # This will be ignored, weights will be loaded manually in the script.
  WEIGHTS: ""

SOLVER:
  STEPS: (120000, 160000)
  MAX_ITER: 180000

VERSION: 2




================================================
FILE: configs/detectron2/voc_det_default_init_24k.yaml
================================================
# -----------------------------------------------------------------------------
# Train a Faster R-CNN with R50-C4 backbone on VOC07+12 detection with any of
# these weight init: random, imagenet (torchvision), virtex or MoCo.
# -----------------------------------------------------------------------------
_BASE_: "_base_faster_rcnn_R_50_C4_BN.yaml"

DATASETS:
  TRAIN: ("voc_2007_trainval", "voc_2012_trainval")
  TEST: ("voc_2007_test",)

INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800

MODEL:
  MASK_ON: False
  ROI_HEADS:
    NUM_CLASSES: 20

  # This will be ignored, weights will be loaded manually in the script.
  WEIGHTS: ""

SOLVER:
  STEPS: (18000, 22000)
  MAX_ITER: 24000
  WARMUP_ITERS: 100

VERSION: 2


================================================
FILE: configs/downstream/imagenet_clf.yaml
================================================
RANDOM_SEED: 0
# Don't need AMP to train a tiny linear layer.
AMP: false
CUDNN_BENCHMARK: true
CUDNN_DETERMINISTIC: false

DATA:
  ROOT: "datasets/imagenet"
  IMAGE_TRANSFORM_TRAIN:
    - "random_resized_crop::{'scale': (0.08, 1.0)}"
    - "horizontal_flip"
    - "normalize"
  IMAGE_TRANSFORM_VAL:
    - "smallest_resize"
    - "center_crop"
    - "normalize"

MODEL:
  VISUAL:
    FROZEN: true

OPTIM:
  BATCH_SIZE: 256
  SGD_MOMENTUM: 0.9
  WEIGHT_DECAY: 0.0
  NO_DECAY: "none"
  LOOKAHEAD:
    USE: false

  LR: 0.3
  WARMUP_STEPS: 0
  LR_DECAY_NAME: "cosine"
  NUM_ITERATIONS: 500500  # 100 epochs


================================================
FILE: configs/downstream/inaturalist_clf.yaml
================================================
RANDOM_SEED: 0
AMP: true
CUDNN_BENCHMARK: true
CUDNN_DETERMINISTIC: false

DATA:
  ROOT: "datasets/inaturalist"
  IMAGE_TRANSFORM_TRAIN:
    - "random_resized_crop::{'scale': (0.08, 1.0)}"
    - "horizontal_flip"
    - "normalize"
  IMAGE_TRANSFORM_VAL:
    - "smallest_resize"
    - "center_crop"
    - "normalize"

MODEL:
  VISUAL:
    FROZEN: false
    
OPTIM:
  BATCH_SIZE: 256
  SGD_MOMENTUM: 0.9
  WEIGHT_DECAY: 0.0001
  NO_DECAY: "none"
  LOOKAHEAD:
    USE: false

  LR: 0.025
  WARMUP_STEPS: 0
  LR_DECAY_NAME: multistep
  LR_GAMMA: 0.1
  LR_STEPS:
    - 119700  # 70 epochs
    - 153900  # 90 epochs
  NUM_ITERATIONS: 171000  # 100 epochs


================================================
FILE: configs/downstream/voc07_clf.yaml
================================================
RANDOM_SEED: 0
DATA:
  ROOT: datasets/VOC2007
  IMAGE_TRANSFORM_TRAIN:
    - smallest_resize
    - center_crop
    - normalize
  IMAGE_TRANSFORM_VAL:
    - smallest_resize
    - center_crop
    - normalize

OPTIM:
  # Only used for feature extraction, doesn't mean much.
  BATCH_SIZE: 128


================================================
FILE: configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  TEXTUAL:
    NAME: "transdec_postnorm::L1_H2048_A32_F8192"


================================================
FILE: configs/task_ablations/captioning_R_50_L1_H2048.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  NAME: "captioning"
  TEXTUAL:
    NAME: "transdec_postnorm::L1_H2048_A32_F8192"


================================================
FILE: configs/task_ablations/masked_lm_R_50_L1_H2048.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  NAME: "masked_lm"
  TEXTUAL:
    NAME: "transdec_postnorm::L1_H2048_A32_F8192"


================================================
FILE: configs/task_ablations/multilabel_classification_R_50.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

DATA:
  VOCAB_SIZE: 81

MODEL:
  NAME: "multilabel_classification"
  TEXTUAL:
    NAME: "none"

OPTIM:
  NO_DECAY: "none"


================================================
FILE: configs/task_ablations/token_classification_R_50.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  NAME: "token_classification"
  TEXTUAL:
    NAME: "none"

OPTIM:
  NO_DECAY: "none"


================================================
FILE: configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"


================================================
FILE: configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  TEXTUAL:
    NAME: "transdec_postnorm::L1_H2048_A32_F8192"


================================================
FILE: configs/width_ablations/bicaptioning_R_50_L1_H512.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  TEXTUAL:
    NAME: "transdec_postnorm::L1_H512_A8_F2048"


================================================
FILE: configs/width_ablations/bicaptioning_R_50_L1_H768.yaml
================================================
_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

MODEL:
  TEXTUAL:
    NAME: "transdec_postnorm::L1_H768_A12_F3072"


================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = .
BUILDDIR      = ../../virtex-sphinx

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)


================================================
FILE: docs/_templates/layout.html
================================================
{% extends "!layout.html" %}

{% block htmltitle %}

    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=UA-120523111-2"></script>
    <script>
    window.dataLayer = window.dataLayer || [];
    function gtag(){dataLayer.push(arguments);}
    gtag('js', new Date());

    gtag('config', 'UA-120523111-2');
    </script>

    <link href="https://fonts.googleapis.com/css?family=Inconsolata&display=swap" rel="stylesheet">
    <link href="https://fonts.googleapis.com/css?family=Ubuntu+Mono&display=swap" rel="stylesheet">

{{ super() }}
{% endblock %}


================================================
FILE: docs/conf.py
================================================
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import inspect
import os
import sys

sys.path.insert(0, os.path.abspath("../"))


# -- Project information -----------------------------------------------------

project = "virtex"
copyright = "2021, Karan Desai and Justin Johnson"
author = "Karan Desai"

# The full version, including alpha/beta/rc tags
release = "1.4"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.coverage",
    "sphinx.ext.doctest",
    "sphinx.ext.linkcode",
    "sphinx.ext.napoleon",
    "sphinx.ext.autosummary",
    "sphinx.ext.intersphinx",
    "sphinx.ext.mathjax",
    "sphinx_copybutton",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffixes as a list of strings:
#
# source_suffix = ['.rst', '.md']
source_suffix = ".rst"

# The master toctree document.
master_doc = "index"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# This version is used underneath the title on the index page.
version = "1.4"
# The following is used if you need to also include a more detailed version.
release = "1.4"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# These patterns also affect html_static_path and html_extra_path.
exclude_patterns = ["_build"]

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False

numpydoc_show_class_members = False


# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]


# -- Autodoc configuration ------------------------------------------------

autodoc_default_options = {
    "members": True,
    "member-order": "bysource",
    "private-members": True,
    "show-inheritance": True,
}


# -- Intersphinx configuration --------------------------------------------

intersphinx_mapping = {
    "torch": ("https://pytorch.org/docs/stable/", None),
    "albumentations": ("https://albumentations.readthedocs.io/en/latest/", None),
}

# -- Miscellaneous Extra Tweaks -------------------------------------------

# make github links resolve
def linkcode_resolve(domain, info):
    """
    Determine the URL corresponding to a Python object.
    This code is from
    https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L290
    and https://github.com/Lasagne/Lasagne/pull/262
    """
    if domain != "py":
        return None

    modname = info["module"]
    fullname = info["fullname"]

    submod = sys.modules.get(modname)
    if submod is None:
        return None

    obj = submod
    for part in fullname.split("."):
        try:
            obj = getattr(obj, part)
        except:  # noqa: E722
            return None

    try:
        fn = inspect.getsourcefile(obj)
    except:  # noqa: E722
        fn = None
    if not fn:
        return None

    try:
        source, lineno = inspect.getsourcelines(obj)
    except:  # noqa: E722
        lineno = None

    if lineno:
        linespec = "#L%d-L%d" % (lineno, lineno + len(source) - 1)
    else:
        linespec = ""

    filename = info["module"].replace(".", "/")
    return f"https://github.com/kdexd/virtex/blob/master/{filename}.py{linespec}"


================================================
FILE: docs/index.rst
================================================
.. raw:: html

    <h1 style="text-align: center">
    VirTex: Learning Visual Representations from Textual Annotations
    </h1>
    <h4 style="text-align: center">
    Karan Desai and Justin Johnson
    <br>
    <span style="font-size: 14pt; color: #555555">
    University of Michigan
    </span>
    </h4>
    <hr>

    <h4 style="text-align: center">
    Abstract
    </h4>

    <p style="text-align: justify">
    The de-facto approach to many vision tasks is to start from pretrained
    visual representations, typically learned via supervised training on
    ImageNet. Recent methods have explored unsupervised pretraining to scale to
    vast quantities of unlabeled images. In contrast, we aim to learn
    high-quality visual representations from fewer images. To this end we
    revisit supervised pretraining, and seek data-efficient alternatives to
    classification-based pretraining. We propose VirTex -- a pretraining
    approach using semantically dense captions to learn visual representations.
    We train convolutional networks from scratch on COCO Captions, and transfer
    them to downstream recognition tasks including image classification, object
    detection, and instance segmentation. On all tasks, VirTex yields features
    that match or exceed those learned on ImageNet -- supervised or unsupervised
    -- despite using up to ten times fewer images.
    </p>

**CVPR 2021. Paper available at:** `arxiv.org/abs/2006.06666 <https://arxiv.org/abs/2006.06666>`_.

**Code available at:** `github.com/kdexd/virtex <https://github.com/kdexd/virtex>`_.

.. image:: _static/system_figure.jpg


Get the pretrained ResNet-50 visual backbone from our best performing VirTex
model in one line *without any installation*!

.. code-block:: python

    import torch

    # That's it, this one line only requires PyTorch.
    model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)


More details in :doc:`virtex/usage/model_zoo`. Next, dive deeper into our
code with User Guide and API References!


User Guide
----------

.. toctree::
    :maxdepth: 2

    virtex/usage/setup_dependencies
    virtex/usage/model_zoo
    virtex/usage/pretrain
    virtex/usage/downstream


API Reference
-------------

.. toctree::
    :maxdepth: 2

    virtex/config
    virtex/factories
    virtex/data
    virtex/models
    virtex/modules
    virtex/optim
    virtex/utils
    virtex/model_zoo


Citation
--------

If you find this code useful, please consider citing:

.. code-block:: text

    @inproceedings{desai2021virtex,
        title={{VirTex: Learning Visual Representations from Textual Annotations}},
        author={Karan Desai and Justin Johnson},
        booktitle={CVPR},
        year={2021}
    }


Acknowledgments
---------------

We thank Harsh Agrawal, Mohamed El Banani, Richard Higgins, Nilesh Kulkarni
and Chris Rockwell for helpful discussions and feedback on the paper. We thank
Ishan Misra for discussions regarding the PIRL evaluation protocol; Saining Xie for
discussions about replicating the iNaturalist evaluation as in MoCo; Ross Girshick and
Yuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting
the Instance Segmentation pretraining task ablation; and Stefan Lee for
suggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs
during project development; and UMich ARC-TS team for support with GPU cluster
management. Finally, we thank all the Starbucks outlets in Ann Arbor for many
hours of free WiFi. This work was partially supported by the Toyota Research
Institute (TRI). However, note that this article solely reflects the opinions
and conclusions of its authors and not TRI or any other Toyota entity.


Indices and Tables
------------------

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`


================================================
FILE: docs/virtex/config.rst
================================================
virtex.config
=============

.. raw:: html

    <hr>

.. automodule:: virtex.config


Config References
-----------------

.. literalinclude:: ../../virtex/config.py
  :language: python
  :linenos:
  :lines: 42-210
  :dedent: 8


================================================
FILE: docs/virtex/data.datasets.rst
================================================
virtex.data.datasets
====================

.. raw:: html

    <hr>

Pretraining Datasets
--------------------

.. automodule:: virtex.data.datasets.coco_captions

.. automodule:: virtex.data.datasets.captioning

.. automodule:: virtex.data.datasets.classification

------------------------------------------------------------------------------

Downstream Datasets
-------------------

.. automodule:: virtex.data.datasets.downstream


================================================
FILE: docs/virtex/data.rst
================================================
virtex.data
===========

.. raw:: html

    <hr>


.. toctree::

    data.datasets
    data.tokenizers
    data.transforms


================================================
FILE: docs/virtex/data.tokenizers.rst
================================================
virtex.data.tokenizers
======================

.. raw:: html

    <hr>

.. automodule:: virtex.data.tokenizers


================================================
FILE: docs/virtex/data.transforms.rst
================================================
virtex.data.transforms
======================

.. raw:: html

    <hr>

.. automodule:: virtex.data.transforms


================================================
FILE: docs/virtex/factories.rst
================================================
virtex.factories
================

.. raw:: html

    <hr>

.. First only include the top-level module, and base class docstrings.

.. automodule:: virtex.factories
    :no-members:

.. autoclass:: virtex.factories.Factory


------------------------------------------------------------------------------

Dataloading-related Factories
-----------------------------

.. autoclass:: virtex.factories.TokenizerFactory
    :members: from_config

.. autoclass:: virtex.factories.ImageTransformsFactory
    :members: from_config

.. autoclass:: virtex.factories.PretrainingDatasetFactory
    :members: from_config

.. autoclass:: virtex.factories.DownstreamDatasetFactory
    :members: from_config

------------------------------------------------------------------------------

Modeling-related Factories
--------------------------

.. autoclass:: virtex.factories.VisualBackboneFactory
    :members: from_config

.. autoclass:: virtex.factories.TextualHeadFactory
    :members: from_config

.. autoclass:: virtex.factories.PretrainingModelFactory
    :members: from_config

------------------------------------------------------------------------------

Optimization-related Factories
------------------------------

.. autoclass:: virtex.factories.OptimizerFactory
    :members: from_config

.. autoclass:: virtex.factories.LRSchedulerFactory
    :members: from_config


================================================
FILE: docs/virtex/model_zoo.rst
================================================
virtex.model_zoo
================

.. raw:: html

    <hr>

.. automodule:: virtex.model_zoo.model_zoo


================================================
FILE: docs/virtex/models.rst
================================================
virtex.models
=============

.. raw:: html

    <hr>

.. automodule:: virtex.models.classification

-------------------------------------------------------------------------------

.. automodule:: virtex.models.captioning

-------------------------------------------------------------------------------

.. automodule:: virtex.models.masked_lm


================================================
FILE: docs/virtex/modules.embedding.rst
================================================
virtex.modules.embedding
========================

.. raw:: html

    <hr>

.. automodule:: virtex.modules.embedding


================================================
FILE: docs/virtex/modules.rst
================================================
virtex.modules
==============

.. raw:: html

    <hr>

.. toctree::

    modules.embedding
    modules.visual_backbones
    modules.textual_heads


================================================
FILE: docs/virtex/modules.textual_heads.rst
================================================
virtex.modules.textual_heads
============================

.. raw:: html

    <hr>

.. automodule:: virtex.modules.textual_heads


================================================
FILE: docs/virtex/modules.visual_backbones.rst
================================================
virtex.modules.visual_backbones
===============================

.. raw:: html

    <hr>

.. automodule:: virtex.modules.visual_backbones


================================================
FILE: docs/virtex/optim.lookahead.rst
================================================
virtex.optim.lookahead
======================

.. raw:: html

    <hr>

.. automodule:: virtex.optim.lookahead


================================================
FILE: docs/virtex/optim.lr_scheduler.rst
================================================
virtex.optim.lr_scheduler
=========================

.. raw:: html

    <hr>

.. automodule:: virtex.optim.lr_scheduler


================================================
FILE: docs/virtex/optim.rst
================================================
virtex.optim
============

.. raw:: html

    <hr>

.. toctree::

    optim.lookahead
    optim.lr_scheduler


================================================
FILE: docs/virtex/usage/downstream.rst
================================================
How to evaluate on downstream tasks?
====================================

In our paper, we evaluate our pretrained VirTex models on seven different
downstream tasks. Our codebase supports all of these evaluations. Throughout
this documentation, we use one specific VirTex pretrained model as a running
example, so that file paths stay uniform across the command snippets below.
Paths can be trivially adjusted for any other VirTex model; evaluating the
baselines (MoCo, ImageNet-supervised, Random Init) requires additional changes
in commands, explained in the last sub-section.

As an example, consider a pretraining job for our best performing VirTex model
(``width_ablations/bicaptioning_R_50_L1_H2048.yaml``). The serialization
directory might look something like this:

.. code-block:: text

    /tmp/bicaptioning_R_50_L1_H2048
        pretrain_config.yaml
        log-rank0.txt    # stdout/stderr per GPU process
        log-rank1.txt
        ...
        log-rank7.txt
        checkpoint_2000.pth
        checkpoint_4000.pth
        ...
        checkpoint_498000.pth
        checkpoint_500000.pth    # serialized checkpoints
        train_captioning_forward/
            events.out.* ...    # tensorboard logs
        ...

We evaluate all checkpoints on **PASCAL VOC 2007 Linear Classification**, and
then evaluate the best checkpoint (here, it was iteration 500000) on all other
downstream tasks.
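
Once every checkpoint has a VOC07 mAP, choosing the one to carry forward is an
argmax over iterations. The sketch below assumes you have already collected the
mAP values into a Python dict keyed by pretraining iteration (for example, read
off the tensorboard curves); ``best_checkpoint`` is a hypothetical helper, not
part of the ``virtex`` package.

.. code-block:: python

    import os
    from typing import Dict


    def best_checkpoint(voc07_map: Dict[int, float], serialization_dir: str) -> str:
        """Return the path of the checkpoint with the highest VOC07 mAP.

        ``voc07_map`` maps a pretraining iteration to its VOC07 mAP. Ties are
        broken in favor of the later iteration.
        """
        if not voc07_map:
            raise ValueError("No VOC07 results to choose from.")
        best_iter = max(voc07_map, key=lambda it: (voc07_map[it], it))
        # Relies on the ``checkpoint_<iteration>.pth`` naming shown above.
        return os.path.join(serialization_dir, f"checkpoint_{best_iter}.pth")

The returned path can be passed as ``--checkpoint-path`` to the commands below.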


PASCAL VOC 2007 Linear Classification
-------------------------------------

Evaluate a single VirTex pretrained checkpoint on the VOC 2007 ``trainval`` split:

.. code-block:: shell

    python scripts/clf_voc07.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --down-config configs/downstream/voc07_clf.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --weight-init virtex \
        --num-gpus-per-machine 1 \
        --cpu-workers 4 \
        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048

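To evaluate a range of recent checkpoints (here, every checkpoint from
iteration 300000 to 500000), loop over this command as follows:

.. code-block:: shell

    for ((iter = 300000; iter <= 500000; iter += 2000)); do
        python scripts/clf_voc07.py --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
            --down-config configs/downstream/voc07_clf.yaml --weight-init virtex \
            --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_$iter.pth \
            --num-gpus-per-machine 1 --cpu-workers 4 --serialization-dir /tmp/bicaptioning_R_50_L1_H2048
    done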

This script writes metrics to tensorboard logs in the same pretraining
directory, so all VOC07 mAP curves appear together with the pretraining loss
curves.
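
If you prefer to drive this loop from Python instead of the shell, the sketch
below lists the serialized checkpoints within an iteration range. It relies
only on the ``checkpoint_<iteration>.pth`` naming shown in the directory
listing above; ``list_checkpoints`` is a hypothetical helper name.

.. code-block:: python

    import os
    import re
    from typing import List, Tuple

    # Matches file names like "checkpoint_500000.pth" and captures the iteration.
    _CKPT_PATTERN = re.compile(r"^checkpoint_(\d+)\.pth$")


    def list_checkpoints(
        serialization_dir: str, start: int, end: int
    ) -> List[Tuple[int, str]]:
        """Return (iteration, path) pairs with start <= iteration <= end,
        sorted by iteration."""
        found = []
        for name in os.listdir(serialization_dir):
            match = _CKPT_PATTERN.match(name)
            if match is None:
                continue  # Skip configs, logs, and tensorboard sub-directories.
            iteration = int(match.group(1))
            if start <= iteration <= end:
                found.append((iteration, os.path.join(serialization_dir, name)))
        return sorted(found)

Sorting numerically by the parsed iteration (rather than by file name) keeps
``checkpoint_98000.pth`` ahead of ``checkpoint_100000.pth``.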

-------------------------------------------------------------------------------

ImageNet Linear Classification
------------------------------

We train a linear classifier on 2048-dimensional global average pooled features
extracted from a frozen visual backbone. Evaluate a checkpoint (for example,
iteration 500000) on this task as:

.. code-block:: shell

    python scripts/clf_linear.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --down-config configs/downstream/imagenet_clf.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --weight-init virtex \
        --num-gpus-per-machine 8 \
        --cpu-workers 4 \
        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/imagenet_500000 \
        --checkpoint-every 5005  # 1 epoch of ImageNet
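
Conceptually, the linear probe reduces each spatial feature map to one vector by
averaging over its height and width, then applies a single affine layer; only
that layer is trained while the backbone stays frozen. The dependency-free
sketch below illustrates the computation on nested lists, with tiny sizes
standing in for the real 2048 channels. It is not the code path used by
``scripts/clf_linear.py``, and both function names are hypothetical.

.. code-block:: python

    from typing import List

    FeatureMap = List[List[List[float]]]  # (channels, height, width)


    def global_average_pool(features: FeatureMap) -> List[float]:
        """Average each channel over all spatial positions: one value per channel."""
        pooled = []
        for channel in features:
            values = [v for row in channel for v in row]
            pooled.append(sum(values) / len(values))
        return pooled


    def linear_logits(
        pooled: List[float], weight: List[List[float]], bias: List[float]
    ) -> List[float]:
        """Compute weight @ pooled + bias; ``weight`` is (num_classes, channels)."""
        return [
            sum(w * x for w, x in zip(row, pooled)) + b
            for row, b in zip(weight, bias)
        ]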

-------------------------------------------------------------------------------

Instance Segmentation (and Object Detection) on COCO
----------------------------------------------------

Train a Mask R-CNN with FPN backbone for COCO Instance Segmentation (and Object
Detection, because it also has a box head) by initializing the backbone from
VirTex pretrained weights:

.. code-block:: shell

    python scripts/eval_detectron2.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --d2-config configs/detectron2/coco_segm_default_init_2x.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --weight-init virtex \
        --num-gpus-per-machine 8 \
        --cpu-workers 2 \
        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/coco_segm_500000 \
        --checkpoint-every 5000

.. note::

    1. This script periodically serializes checkpoints but skips the
       validation step during training to save time. To evaluate a serialized
       checkpoint and write results to tensorboard, provide it as
       ``--checkpoint-path`` along with the additional flags
       ``--resume --eval-only``.

    2. ``--d2-config`` here is in Detectron2 format, not our package's
       :class:`~virtex.config.Config`.

    These points are applicable for all tasks described below.
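
As an illustration of point 1, the hypothetical helper below rebuilds the COCO
training command from above as an argument list and appends
``--resume --eval-only``. It uses only flags that appear on this page; which
serialized checkpoint file to pass depends on what your training run saved, so
the path is left to the caller.

.. code-block:: python

    from typing import List


    def eval_only_command(
        checkpoint_path: str, d2_config: str, serialization_dir: str
    ) -> List[str]:
        """Build the argument list to evaluate one serialized Detectron2 checkpoint."""
        return [
            "python", "scripts/eval_detectron2.py",
            "--config", "/tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml",
            "--d2-config", d2_config,
            "--checkpoint-path", checkpoint_path,
            "--weight-init", "virtex",
            "--num-gpus-per-machine", "8",
            "--cpu-workers", "2",
            "--serialization-dir", serialization_dir,
            # Load the given checkpoint and only run evaluation.
            "--resume", "--eval-only",
        ]

The list can be passed directly to ``subprocess.run(cmd, check=True)``, which
avoids shell quoting issues when looping over several checkpoints.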

-------------------------------------------------------------------------------

Instance Segmentation on LVIS
-----------------------------

Train a Mask R-CNN with FPN backbone for LVIS Instance Segmentation by
initializing the backbone from VirTex pretrained weights:

.. code-block:: shell

    python scripts/eval_detectron2.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --d2-config configs/detectron2/lvis_segm_default_init_2x.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --weight-init virtex \
        --num-gpus-per-machine 8 \
        --cpu-workers 2 \
        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/lvis_segm_500000 \
        --checkpoint-every 5000

-------------------------------------------------------------------------------

Object Detection on PASCAL VOC 2007+12
--------------------------------------

Train a Faster R-CNN with C4 backbone for PASCAL VOC 2007+12 Object Detection
by initializing the backbone from VirTex pretrained weights:

.. code-block:: shell

    python scripts/eval_detectron2.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --d2-config configs/detectron2/voc_det_default_init_24k.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --weight-init virtex \
        --num-gpus-per-machine 8 \
        --cpu-workers 2 \
        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/voc_det_500000 \
        --checkpoint-every 2500

-------------------------------------------------------------------------------

iNaturalist 2018 Fine-Grained Classification
--------------------------------------------

Fine-tune the VirTex pretrained visual backbone end-to-end on iNaturalist 2018
dataset:

.. code-block:: shell

    python scripts/clf_linear.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --down-config configs/downstream/inaturalist_clf.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --weight-init virtex \
        --num-gpus-per-machine 8 \
        --cpu-workers 4 \
        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/inaturalist_500000 \
        --checkpoint-every 1710  # 1 epoch of iNaturalist
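
The ``--checkpoint-every`` values for ImageNet (5005) and iNaturalist (1710)
correspond to one epoch in iterations, i.e. ``ceil(num_train_images / batch_size)``.
Assuming a total batch size of 256 (our assumption; check the downstream config
you use) and the standard training-set sizes of ImageNet (1,281,167 images) and
iNaturalist 2018 (437,513 images), the arithmetic reproduces both numbers:

.. code-block:: python

    import math


    def iterations_per_epoch(num_train_images: int, batch_size: int) -> int:
        """Number of optimizer steps needed to see every training image once."""
        return math.ceil(num_train_images / batch_size)


    # Assumed total batch size (across all GPUs); adjust to match your config.
    BATCH_SIZE = 256

    print(iterations_per_epoch(1_281_167, BATCH_SIZE))  # ImageNet train: 5005
    print(iterations_per_epoch(437_513, BATCH_SIZE))    # iNaturalist 2018 train: 1710

If you change the batch size, recompute ``--checkpoint-every`` the same way to
keep checkpoints aligned with epoch boundaries.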

-------------------------------------------------------------------------------

Image Captioning on COCO Captions val2017
-----------------------------------------

Evaluate a pretrained VirTex model on image captioning for the COCO Captions
val2017 split (reporting CIDEr and SPICE metrics):

.. code-block:: shell

    python scripts/eval_captioning.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --calc-metrics \
        --num-gpus-per-machine 1 \
        --cpu-workers 4

-------------------------------------------------------------------------------

Running Image Captioning Inference on Arbitrary Images
------------------------------------------------------

The above script can also generate captions for arbitrary images in a
directory. Replace the relevant arguments as follows:

.. code-block:: shell

    python scripts/eval_captioning.py \
        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \
        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \
        --data-root /path/to/images_dir \
        --output /path/to/save/predictions.json \
        --num-gpus-per-machine 1 \
        --cpu-workers 4

This script saves predictions in JSON format. Since our goal is not to
improve image captioning, these models may not generate the best captions.
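
To inspect the saved predictions, something like the sketch below should work.
It assumes the JSON file holds a list of records with ``image_id`` and
``caption`` keys (the format used for COCO Captions results); we have not
confirmed this against ``scripts/eval_captioning.py``, so check a saved file
first and adjust the keys if they differ. ``load_captions`` is a hypothetical
helper name.

.. code-block:: python

    import json
    from typing import Dict


    def load_captions(predictions_path: str) -> Dict[str, str]:
        """Map each image id (as a string) to its generated caption.

        Assumes the file contains a JSON list of {"image_id": ..., "caption": ...}.
        """
        with open(predictions_path) as f:
            predictions = json.load(f)
        return {str(p["image_id"]): p["caption"] for p in predictions}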


================================================
FILE: docs/virtex/usage/model_zoo.rst
================================================
VirTex Model Zoo
================

We provide a collection of pretrained model weights and corresponding config
names in this model zoo. For each model, the tables list the partial path to
its config file, a download link for its pretrained weights and, for
reference, its VOC07 mAP and ImageNet top-1 accuracy.

The simplest way to download and use a *full* pretrained model (including both
the visual backbone and the textual head) is through the :doc:`../model_zoo`
API as follows. This code snippet works from anywhere and does not need to be
executed from the project root.

.. code-block:: python

    # Get our full best performing VirTex model:
    import virtex.model_zoo as mz
    model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)

    # Optionally extract the torchvision-like visual backbone (with ``avgpool``
    # and ``fc`` layers replaced with ``nn.Identity`` module).
    cnn = model.visual.cnn

Alternatively, download the weights manually from the links below, and run the
following from the project root:

.. code-block:: python

    from virtex.config import Config
    from virtex.factories import PretrainingModelFactory
    from virtex.utils.checkpointing import CheckpointManager

    # Get the best performing VirTex model:
    _C = Config("configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml")
    model = PretrainingModelFactory.from_config(_C)

    CheckpointManager(model=model).load("/path/to/downloaded/weights.pth")

    # Optionally extract the torchvision-like visual backbone (with ``avgpool``
    # and ``fc`` layers replaced with ``nn.Identity`` module).
    cnn = model.visual.cnn
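
The model links in the tables below are Dropbox share links ending in
``?dl=0``, which open a preview page in the browser. For scripted downloads
(for example with ``wget`` or ``curl``), Dropbox returns the raw file when the
query is ``dl=1`` instead. A small standard-library helper for that rewrite
(the function name is hypothetical):

.. code-block:: python

    from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


    def direct_download_url(share_url: str) -> str:
        """Rewrite a Dropbox share link so that it downloads the file directly."""
        parts = urlsplit(share_url)
        query = dict(parse_qsl(parts.query))
        query["dl"] = "1"  # dl=0 shows a preview page; dl=1 returns the file itself.
        return urlunsplit(parts._replace(query=urlencode(query)))

The resulting URL can be passed to any HTTP client, and the downloaded file
loaded with ``CheckpointManager`` as shown above.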


The pretrained ResNet-50 visual backbone of our best performing model
(``width_ablations/bicaptioning_R_50_L1_H2048.yaml``) can be loaded in a single
line, *without following any installation steps* (only requires PyTorch v1.5):

.. code-block:: python

    import torch

    model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)

    # This is a torchvision-like resnet50 model, with ``avgpool`` and ``fc``
    # layers replaced with ``nn.Identity`` module.
    image_batch = torch.randn(1, 3, 224, 224)  # batch tensor of one image.
    features_batch = model(image_batch)  # shape: (1, 2048, 7, 7)

-------------------------------------------------------------------------------

Pretraining Task Ablations
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. raw:: html

    <style type="text/css">
    .tg  {border-collapse:collapse;border-spacing:0;}
    .tg td{border-color:black;border-style:solid;border-width:1px;
    overflow:hidden;padding:10px 5px;word-break:normal;}
    .tg th{border-color:black;border-style:solid;border-width:1px;
    font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}
    .tg .tg-zlqz{background-color:#d5d5d5;border-color:inherit;font-weight:bold;text-align:center;vertical-align:center}
    .tg .tg-c3ow{border-color:inherit;text-align:center;vertical-align:top}
    .tg .tg-c3ow a{color: darkgreen; text-decoration: none; border-bottom: 1px dashed green;text-underline-position: under;}
    .tg .tg-c3ow a:hover{font-weight: 700;border-bottom: 1px solid green;}
    .tg .tg-0pky{border-color:inherit;text-align:left;vertical-align:top}
    @media screen and (max-width: 767px) {.tg {width: auto !important;}.tg col {width: auto !important;}.tg-wrap {overflow-x: auto;-webkit-overflow-scrolling: touch;}}</style>
    <div class="tg-wrap"><table class="tg">
    <tbody>
    <tr>
        <td class="tg-zlqz">Model Config Name</td>
        <td class="tg-zlqz">VOC07<br>mAP</td>
        <td class="tg-zlqz">ImageNet<br>Top-1 Acc.</td>
        <td class="tg-zlqz">Model URL</td>
    </tr>
    <tr>
        <td class="tg-0pky">task_ablations/bicaptioning_R_50_L1_H2048.yaml</td>
        <td class="tg-c3ow">88.7</td>
        <td class="tg-c3ow">53.8</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/mbeeso8wyieq8wy/bicaptioning_R_50_L1_H2048.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky">task_ablations/captioning_R_50_L1_H2048.yaml</td>
        <td class="tg-c3ow">88.6</td>
        <td class="tg-c3ow">50.8</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/r6zen9k43m5oo58/captioning_R_50_L1_H2048.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky">task_ablations/token_classification_R_50.yaml</td>
        <td class="tg-c3ow">88.8</td>
        <td class="tg-c3ow">48.6</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/o4p9lki505r0mef/token_classification_R_50.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky">task_ablations/multilabel_classification_R_50.yaml</td>
        <td class="tg-c3ow">86.2</td>
        <td class="tg-c3ow">46.2</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/hbspp3jv3u8h3bc/multilabel_classification_R_50.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky">task_ablations/masked_lm_R_50_L1_H2048.yaml</td>
        <td class="tg-c3ow">86.4</td>
        <td class="tg-c3ow">46.7</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/ldzrk6vem4mg6bl/masked_lm_R_50_L1_H2048.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    </tbody>
    </table></div>


Width Ablations
^^^^^^^^^^^^^^^

.. raw:: html

    <div class="tg-wrap"><table class="tg">
    <tbody>
    <tr>
        <td class="tg-zlqz">Model Config Name</td>
        <td class="tg-zlqz">VOC07<br>mAP</td>
        <td class="tg-zlqz">ImageNet<br>Top-1 Acc.</td>
        <td class="tg-zlqz">Model URL</td>
    </tr>
    <tr>
        <td class="tg-0pky">width_ablations/bicaptioning_R_50_L1_H512.yaml</td>
        <td class="tg-c3ow">88.4</td>
        <td class="tg-c3ow">51.8</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/o9fr69jjqfn8a65/bicaptioning_R_50_L1_H512.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky"><span style="font-weight:400;font-style:normal">width_ablations/bicaptioning_R_50_L1_H768.yaml</span></td>
        <td class="tg-c3ow">88.3</td>
        <td class="tg-c3ow">52.3</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/1zxglqrrbfufv9d/bicaptioning_R_50_L1_H768.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky"><span style="font-weight:400;font-style:normal">width_ablations/bicaptioning_R_50_L1_H1024.yaml</span></td>
        <td class="tg-c3ow">88.3</td>
        <td class="tg-c3ow">53.2</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/pdat4tvhnqxel64/bicaptioning_R_50_L1_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky"><span style="font-weight:400;font-style:normal">width_ablations/bicaptioning_R_50_L1_H2048.yaml</span></td>
        <td class="tg-c3ow">88.7</td>
        <td class="tg-c3ow">53.8</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/mbeeso8wyieq8wy/bicaptioning_R_50_L1_H2048.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    </tbody>
    </table></div>


Depth Ablations
^^^^^^^^^^^^^^^

.. raw:: html

    <div class="tg-wrap"><table class="tg">
    <tbody>
    <tr>
        <td class="tg-zlqz">Model Config Name</td>
        <td class="tg-zlqz">VOC07<br>mAP</td>
        <td class="tg-zlqz">ImageNet<br>Top-1 Acc.</td>
        <td class="tg-zlqz">Model URL</td>
    </tr>
    <tr>
        <td class="tg-0pky">depth_ablations/bicaptioning_R_50_L1_H1024.yaml</td>
        <td class="tg-c3ow">88.3</td>
        <td class="tg-c3ow">53.2</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/pdat4tvhnqxel64/bicaptioning_R_50_L1_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky">depth_ablations/bicaptioning_R_50_L2_H1024.yaml</td>
        <td class="tg-c3ow">88.8</td>
        <td class="tg-c3ow">53.8</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/ft1vtt4okirzjgo/bicaptioning_R_50_L2_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky"><span style="font-weight:400;font-style:normal">depth_ablations/bicaptioning_R_50_L3_H1024.yaml</span></td>
        <td class="tg-c3ow">88.7</td>
        <td class="tg-c3ow">53.9</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/5ldo1rcsnrshmjr/bicaptioning_R_50_L3_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky"><span style="font-weight:400;font-style:normal">depth_ablations/bicaptioning_R_50_L4_H1024.yaml</span></td>
        <td class="tg-c3ow">88.7</td>
        <td class="tg-c3ow">53.9</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/zgiit2wcluuq3xh/bicaptioning_R_50_L4_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    </tbody>
    </table></div>


Backbone Ablations
^^^^^^^^^^^^^^^^^^

.. raw:: html

    <div class="tg-wrap"><table class="tg">
    <tbody>
    <tr>
        <td class="tg-zlqz">Model Config Name</td>
        <td class="tg-zlqz">VOC07<br>mAP</td>
        <td class="tg-zlqz">ImageNet<br>Top-1 Acc.</td>
        <td class="tg-zlqz">Model URL</td>
    </tr>
    <tr>
        <td class="tg-0pky">backbone_ablations/bicaptioning_R_50_L1_H1024.yaml</td>
        <td class="tg-c3ow">88.3</td>
        <td class="tg-c3ow">53.2</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/pdat4tvhnqxel64/bicaptioning_R_50_L1_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky">backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml</td>
        <td class="tg-c3ow">88.5</td>
        <td class="tg-c3ow">52.9</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/5o198ux709r6376/bicaptioning_R_50W2X_L1_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    <tr>
        <td class="tg-0pky">backbone_ablations/bicaptioning_R_101_L1_H1024.yaml</td>
        <td class="tg-c3ow">88.7</td>
        <td class="tg-c3ow">52.1</td>
        <td class="tg-c3ow"><a href="https://www.dropbox.com/s/bb74jubt68cpn80/bicaptioning_R_101_L1_H1024.pth?dl=0" target="_blank" rel="noopener noreferrer">model</a></td>
    </tr>
    </tbody>
    </table></div>


================================================
FILE: docs/virtex/usage/pretrain.rst
================================================
How to train your VirTex model?
===============================

We provide training scripts for all types of VirTex models from the paper,
including our best-performing model and other ablations.
Our training jobs are specified by config files (YAML).
Execute all commands from the project root to use the provided config files.


Training the base VirTex model
------------------------------

Train the base VirTex model with a ResNet-50 visual backbone and a textual head
with ``L = 1, H = 1024``, using all default optimization hyperparameters.

.. code-block::

    python scripts/pretrain_virtex.py \
        --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \
        --num-gpus-per-machine 8 \
        --cpu-workers 4 \
        --serialization-dir /tmp/VIRTEX_R_50_L1_H1024
        # Default: --checkpoint-every 2000 --log-every 20

The training job saves checkpoints, tensorboard logs (loss curves and metrics),
and a backup of the config in ``--serialization-dir``. Run ``tensorboard --logdir
<serialization_dir>`` to view training curves, validation metrics, etc. in
tensorboard.

We recommend training with 8 GPUs on the same machine, although training with
multiple GPUs across machines (see: ``--num-machines`` and ``--machine-rank``),
a single GPU (``--num-gpus-per-machine 1``), or the CPU
(``--num-gpus-per-machine 0``) is also supported. Interactive debugging with PDB
is not supported when using multiple GPUs, because PDB and the
``multiprocessing`` module do not work well together.
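
The sketch below illustrates how these flags combine into processes. It assumes
the Detectron2-style convention that ``virtex.utils.distributed.launch``
appears to follow (global rank = machine rank times GPUs per machine, plus the
local rank), and that the global batch size is split evenly across processes
with floor division, as ``scripts/clf_linear.py`` does with
``OPTIM.BATCH_SIZE // dist.get_world_size()``. ``describe_processes`` and
``ProcessInfo`` are hypothetical names, not part of the codebase.

.. code-block:: python

    from typing import List, NamedTuple


    class ProcessInfo(NamedTuple):
        machine_rank: int
        local_rank: int
        global_rank: int
        world_size: int
        per_gpu_batch_size: int


    def describe_processes(
        num_machines: int, num_gpus_per_machine: int, batch_size: int
    ) -> List[ProcessInfo]:
        """Enumerate every process of a (hypothetical) multi-machine job."""
        world_size = num_machines * num_gpus_per_machine
        processes = []
        for machine_rank in range(num_machines):
            for local_rank in range(num_gpus_per_machine):
                processes.append(
                    ProcessInfo(
                        machine_rank=machine_rank,
                        local_rank=local_rank,
                        # All GPUs of machine 0 come first, then machine 1, etc.
                        global_rank=machine_rank * num_gpus_per_machine + local_rank,
                        world_size=world_size,
                        # Floor division, as in scripts/clf_linear.py.
                        per_gpu_batch_size=batch_size // world_size,
                    )
                )
        return processes


    # Two machines with 8 GPUs each and a hypothetical global batch size of 256.
    procs = describe_processes(num_machines=2, num_gpus_per_machine=8, batch_size=256)
    print(procs[9].global_rank, procs[9].per_gpu_batch_size)  # 9 16

With ``--num-gpus-per-machine 0``, ``scripts/clf_linear.py`` skips
``dist.launch`` and calls ``main`` directly in a single CPU process.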

-------------------------------------------------------------------------------

Reproducing all VirTex ablations
--------------------------------

To reproduce all ablations from the `paper <https://arxiv.org/abs/2006.06666>`_,
replace the ``--config`` argument in the above command with one of the
following (all paths are relative to the project root):

Pretraining Task Ablations
^^^^^^^^^^^^^^^^^^^^^^^^^^

1. **Bicaptioning:** configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml
2. **Forward Captioning:** configs/task_ablations/captioning_R_50_L1_H2048.yaml
3. **Token Classification:** configs/task_ablations/token_classification_R_50.yaml
4. **Multilabel Classification:** configs/task_ablations/multilabel_classification_R_50.yaml
5. **Masked Language Modeling:** configs/task_ablations/masked_lm_R_50_L1_H2048.yaml

Transformer Size Ablations
^^^^^^^^^^^^^^^^^^^^^^^^^^

1. **Width (H = 512):** configs/width_ablations/bicaptioning_R_50_L1_H512.yaml
2. **Width (H = 768):** configs/width_ablations/bicaptioning_R_50_L1_H768.yaml
3. **Width (H = 1024):** configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml
4. **Width (H = 2048):** configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml
5. **Depth (L = 1):** configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml
6. **Depth (L = 2):** configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml
7. **Depth (L = 3):** configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml
8. **Depth (L = 4):** configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml

Backbone Ablations
^^^^^^^^^^^^^^^^^^

1. **ResNet-50:** configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml
2. **ResNet-50 w2x:** configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml
3. **ResNet-101:** configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml

.. note::

    **Pretraining Task Ablations** (1) and **Transformer Size Ablations** (4)
    are the same model. **Transformer Size Ablations** (3 and 5) and
    **Backbone Ablations** (1) are also the same model.
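
To queue up every ablation at once, the commands can be generated with a short
script such as the one below. It is a sketch, not part of the codebase: it
assumes that config files sharing a file name (for example the three
``bicaptioning_R_50_L1_H1024.yaml`` files under ``width_ablations``,
``depth_ablations`` and ``backbone_ablations``) describe the same model, so each
such model is trained only once. The ``/tmp/VIRTEX_<name>`` serialization
directories are an illustrative choice.

.. code-block:: python

    import os
    import shlex
    from typing import List

    # Config files listed above, relative to the project root.
    ABLATION_CONFIGS = [
        "configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml",
        "configs/task_ablations/captioning_R_50_L1_H2048.yaml",
        "configs/task_ablations/token_classification_R_50.yaml",
        "configs/task_ablations/multilabel_classification_R_50.yaml",
        "configs/task_ablations/masked_lm_R_50_L1_H2048.yaml",
        "configs/width_ablations/bicaptioning_R_50_L1_H512.yaml",
        "configs/width_ablations/bicaptioning_R_50_L1_H768.yaml",
        "configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml",
        "configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml",
        "configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml",
        "configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml",
        "configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml",
        "configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml",
        "configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml",
        "configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml",
        "configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml",
    ]


    def build_commands(configs: List[str], num_gpus: int = 8) -> List[str]:
        """Build one pretraining command per distinct model.

        Assumption: configs with the same file name describe the same model,
        so only the first occurrence of each file name is kept.
        """
        seen = set()
        commands = []
        for config in configs:
            name = os.path.splitext(os.path.basename(config))[0]
            if name in seen:
                continue
            seen.add(name)
            args = [
                "python", "scripts/pretrain_virtex.py",
                "--config", config,
                "--num-gpus-per-machine", str(num_gpus),
                "--cpu-workers", "4",
                # Hypothetical naming scheme for serialization directories.
                "--serialization-dir", f"/tmp/VIRTEX_{name}",
            ]
            commands.append(" ".join(shlex.quote(arg) for arg in args))
        return commands


    for command in build_commands(ABLATION_CONFIGS):
        print(command)

Of the 16 listed configs, this yields 13 commands; the commands can then be run
sequentially or handed to a job scheduler.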


================================================
FILE: docs/virtex/usage/setup_dependencies.rst
================================================
How to setup this codebase?
===========================

.. raw:: html

    <hr>

This codebase requires Python 3.6 or higher. We recommend using Anaconda or
Miniconda. We walk through installation and data preprocessing here.


Install Dependencies
--------------------

These steps install the dependencies through Anaconda (or Miniconda).

1. Install an Anaconda or Miniconda distribution based on Python 3 from their
   `downloads site <https://conda.io/docs/user-guide/install/download.html>`_.


2. Clone the repository first.

    .. code-block:: shell

        git clone https://www.github.com/kdexd/virtex


3. Create a conda environment and install all the dependencies.

    .. code-block:: shell

        cd virtex
        conda create -n virtex python=3.8
        conda activate virtex
        pip install -r requirements.txt


4. Install additional packages from GitHub.

    .. code-block:: shell

        pip install git+https://github.com/facebookresearch/fvcore.git#egg=fvcore
        pip install git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI


5. Install this codebase as a package in development mode.

    .. code-block:: shell

        python setup.py develop

Now you can ``import virtex`` from anywhere as long as you have this conda
environment activated.
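
To confirm the installation, a quick check like the following can report which
packages are importable from the activated environment. It is a hypothetical
helper, not part of the codebase; the module names are assumptions about what
the steps above install (``pycocotools`` is the module provided by cocoapi's
``PythonAPI``). ``importlib.util.find_spec`` locates a top-level module without
importing it, so the check stays fast.

.. code-block:: python

    import importlib.util
    import sys
    from typing import Dict, Iterable

    # Assumed top-level module names for the packages installed above.
    REQUIRED_MODULES = (
        "torch",
        "torchvision",
        "sentencepiece",
        "fvcore",
        "pycocotools",
        "virtex",
    )


    def check_environment(modules: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
        """Map each module name to whether it can be found (without importing it)."""
        if sys.version_info < (3, 6):
            raise RuntimeError("This codebase requires Python 3.6 or higher.")
        return {name: importlib.util.find_spec(name) is not None for name in modules}


    if __name__ == "__main__":
        for name, found in check_environment().items():
            print(f"{name:<15}{'ok' if found else 'MISSING'}")

A ``MISSING`` entry usually means the conda environment is not activated, or
one of the installation steps above did not complete.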

-------------------------------------------------------------------------------


Setup Datasets
--------------

Datasets are assumed to exist in the ``./datasets`` directory (relative to the
project root), following the structure specified below. COCO is used for
pretraining, and all datasets below (including COCO) are used for downstream
tasks. This structure is also compatible with
`Detectron2 <https://github.com/facebookresearch/detectron2>`_ for downstream
tasks.

COCO
^^^^
.. code-block::

    datasets/coco/
        annotations/
            captions_{train,val}2017.json
            instances_{train,val}2017.json
        train2017/
            # images in train2017 split
        val2017/
            # images in val2017 split

LVIS
^^^^
.. code-block::

    datasets/coco/
        train2017/
        val2017/
    datasets/lvis/
        lvis_v1.0_{train,val}.json

PASCAL VOC
^^^^^^^^^^
.. code-block::

    datasets/VOC2007/
        Annotations/
        ImageSets/
            Main/
                trainval.txt
                test.txt
        JPEGImages/

    datasets/VOC2012/
        # Same as VOC2007 above

ImageNet
^^^^^^^^
.. code-block::

    datasets/imagenet/
        train/
            # One directory per category with images in it
        val/
            # One directory per category with images in it
        ILSVRC2012_devkit_t12.tar.gz

iNaturalist 2018
^^^^^^^^^^^^^^^^
.. code-block::

    datasets/inaturalist/
        train_val2018/
        annotations/
            train2018.json
            val2018.json
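
Before launching jobs, it can help to verify this layout. The sketch below
checks a subset of the paths listed above (VOC2012 is omitted for brevity);
``EXPECTED_PATHS`` and ``missing_paths`` are hypothetical names, and the
function only tests that each path exists, not that its contents are complete
or valid.

.. code-block:: python

    from pathlib import Path
    from typing import Dict, List

    # Paths taken from the directory trees above, relative to ``datasets/``.
    EXPECTED_PATHS: Dict[str, List[str]] = {
        "coco": [
            "coco/annotations/captions_train2017.json",
            "coco/annotations/captions_val2017.json",
            "coco/annotations/instances_train2017.json",
            "coco/annotations/instances_val2017.json",
            "coco/train2017",
            "coco/val2017",
        ],
        "lvis": ["lvis/lvis_v1.0_train.json", "lvis/lvis_v1.0_val.json"],
        "voc07": [
            "VOC2007/Annotations",
            "VOC2007/ImageSets/Main/trainval.txt",
            "VOC2007/ImageSets/Main/test.txt",
            "VOC2007/JPEGImages",
        ],
        "imagenet": [
            "imagenet/train",
            "imagenet/val",
            "imagenet/ILSVRC2012_devkit_t12.tar.gz",
        ],
        "inaturalist": [
            "inaturalist/train_val2018",
            "inaturalist/annotations/train2018.json",
            "inaturalist/annotations/val2018.json",
        ],
    }


    def missing_paths(root: str = "datasets") -> Dict[str, List[str]]:
        """Return, per dataset, the expected paths that do not exist under root."""
        root_path = Path(root)
        return {
            name: [p for p in paths if not (root_path / p).exists()]
            for name, paths in EXPECTED_PATHS.items()
        }


    if __name__ == "__main__":
        for name, missing in missing_paths().items():
            print(name, "OK" if not missing else f"missing: {missing}")

Datasets you do not plan to use can simply be ignored in the report.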

-------------------------------------------------------------------------------


Build vocabulary
----------------

Build a vocabulary from the COCO Captions ``train2017`` split.

.. code-block:: shell

    python scripts/build_vocabulary.py \
        --captions datasets/coco/annotations/captions_train2017.json \
        --vocab-size 10000 \
        --output-prefix datasets/vocab/coco_10k \
        --do-lower-case

That's it! You are all set to use this codebase.
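
For reference, the caption preprocessing that ``scripts/build_vocabulary.py``
applies before training the SentencePiece model (lower-casing with
``--do-lower-case``, and stripping accents unless ``--keep-accents`` is passed)
can be written as a standalone function. ``normalize_caption`` is a
hypothetical name; it mirrors the loop in the script, with the same defaults as
the script's flags.

.. code-block:: python

    import unicodedata


    def normalize_caption(
        caption: str, do_lower_case: bool = False, keep_accents: bool = False
    ) -> str:
        """Mirror the caption preprocessing loop in scripts/build_vocabulary.py."""
        if do_lower_case:
            caption = caption.lower()
        if not keep_accents:
            # NFKD splits accented characters into a base character followed
            # by combining marks, which are then dropped.
            caption = unicodedata.normalize("NFKD", caption)
            caption = "".join(ch for ch in caption if not unicodedata.combining(ch))
        return caption


    print(normalize_caption("A Café with a Piñata.", do_lower_case=True))
    # a cafe with a pinata.

Because the vocabulary is built from normalized text, captions are expected to
go through the same normalization before tokenization.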


================================================
FILE: docs/virtex/utils.beam_search.rst
================================================
virtex.utils.beam_search
========================

.. raw:: html

    <hr>

.. automodule:: virtex.utils.beam_search


================================================
FILE: docs/virtex/utils.checkpointing.rst
================================================
virtex.utils.checkpointing
==========================

.. raw:: html

    <hr>

.. automodule:: virtex.utils.checkpointing


================================================
FILE: docs/virtex/utils.common.rst
================================================
virtex.utils.common
===================

.. raw:: html

    <hr>

.. automodule:: virtex.utils.common


================================================
FILE: docs/virtex/utils.distributed.rst
================================================
virtex.utils.distributed
========================

.. raw:: html

    <hr>

.. automodule:: virtex.utils.distributed


================================================
FILE: docs/virtex/utils.metrics.rst
================================================
virtex.utils.metrics
====================

.. raw:: html

    <hr>

.. automodule:: virtex.utils.metrics


================================================
FILE: docs/virtex/utils.rst
================================================
virtex.utils
============

.. raw:: html

    <hr>

.. toctree::

    utils.common
    utils.distributed
    utils.timer
    utils.checkpointing
    utils.beam_search
    utils.metrics


================================================
FILE: docs/virtex/utils.timer.rst
================================================
virtex.utils.timer
==================

.. raw:: html

    <hr>

.. automodule:: virtex.utils.timer


================================================
FILE: hubconf.py
================================================
dependencies = ["torch"]

import torch
import torchvision


R50_URL = "https://www.dropbox.com/s/pxgjxcva7oypf12/backbone_bicaptioning_R_50_L1_H2048.pth?dl=1"


def resnet50(pretrained: bool = False, **kwargs):
    r"""
    ResNet-50 visual backbone from the best performing VirTex model: pretrained
    for bicaptioning on COCO Captions, with textual head ``L = 1, H = 2048``.

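    This is a torchvision-like model, with the last ``avgpool`` and ``fc``
    modules replaced with ``nn.Identity()`` modules. Given a batch of image
    tensors with size ``(B, 3, 224, 224)``, this model computes spatial image
    features with 2048 channels on a 7x7 grid, where B = batch size. Since
    torchvision's ``ResNet.forward`` flattens its output, these features are
    returned with size ``(B, 2048 * 7 * 7)``; reshape them to
    ``(B, 2048, 7, 7)`` to recover the spatial layout.

    Args:
        pretrained (bool): Whether to load model with pretrained weights.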
    """

    # Create a torchvision resnet50 with randomly initialized weights.
    model = torchvision.models.resnet50(pretrained=False, **kwargs)

    # Replace global average pooling and fully connected layers with identity
    # modules.
    model.avgpool = torch.nn.Identity()
    model.fc = torch.nn.Identity()

    if pretrained:
        model.load_state_dict(
            torch.hub.load_state_dict_from_url(R50_URL, progress=False)
        )
    return model


================================================
FILE: requirements.txt
================================================
albumentations>=1.0
Cython>=0.25
future==0.18.0
loguru>=0.3
lvis>=0.5
numpy>=1.17
opencv-python>=4.2.0
scikit-learn>=1.0
sentencepiece>=0.1.90
torch>=1.9
torchvision>=0.10
tqdm>=4.50.0


================================================
FILE: scripts/build_vocabulary.py
================================================
import argparse
import json
import os
import tempfile
import unicodedata
from typing import List

import sentencepiece as sp


# fmt: off
parser = argparse.ArgumentParser(
    description="""Build a vocabulary from a captions corpus. The vocabulary
    is saved as files which our tokenizer can read.
    """
)
parser.add_argument(
    "-c", "--captions", default="datasets/coco/annotations/captions_train2017.json",
    help="Path to caption annotations file in COCO format.",
)
parser.add_argument(
    "-s", "--vocab-size", type=int, default=10000,
    help="Total desired size of our vocabulary.",
)
parser.add_argument(
    "-o", "--output-prefix", default="datasets/vocab/coco_10k",
    help="Prefix of the files to be saved. Two files will be saved: "
    "[prefix].model and [prefix].vocab",
)
parser.add_argument(
    "-l", "--do-lower-case", action="store_true",
    help="Whether to lower case the captions before forming vocabulary.",
)
parser.add_argument(
    "-a", "--keep-accents", action="store_true",
    help="Whether to keep accents before forming vocabulary (dropped by default).",
)
# fmt: on


def _read_captions(annotations_path: str) -> List[str]:
    r"""
    Given a path to annotation file, read it and return a list of captions.
    These are not processed by any means, returned from the file as-is.

    Args:
        annotations_path: Path to an annotations file containing captions.

    Returns:
        List of captions from this annotation file.
    """

    with open(annotations_path) as annotations_file:
        _annotations = json.load(annotations_file)

    captions: List[str] = []
    for ann in _annotations["annotations"]:
        captions.append(ann["caption"])

    return captions


if __name__ == "__main__":
    _A = parser.parse_args()
    captions: List[str] = _read_captions(_A.captions)

    # Lower case the captions and remove accents according to arguments.
    for i, caption in enumerate(captions):
        caption = caption.lower() if _A.do_lower_case else caption

        if not _A.keep_accents:
            caption = unicodedata.normalize("NFKD", caption)
            caption = "".join(
                [ch for ch in caption if not unicodedata.combining(ch)]
            )

        captions[i] = caption

    # Create a temporary directory and dump the captions corpus as a text file
    # with one caption per line. That's how sentencepiece wants its input.
    tmpdir_path = tempfile.mkdtemp()

    with open(os.path.join(tmpdir_path, "captions.txt"), "w") as captions_file:
        for caption in captions:
            captions_file.write(caption + "\n")

    # Padding/out-of-vocab token will be "<unk>" and ID 0 by default.
    # Add [SOS],[EOS] and [MASK] tokens. [MASK] will not be used during
    # captioning, but good to have to reuse vocabulary across pretext tasks.
    sp.SentencePieceTrainer.train(
        f" --input={os.path.join(tmpdir_path, 'captions.txt')}"
        f" --vocab_size={_A.vocab_size}"
        f" --model_prefix={_A.output_prefix}"
        " --model_type=bpe --character_coverage=1.0"
        " --bos_id=-1 --eos_id=-1"
        " --control_symbols=[SOS],[EOS],[MASK]"
    )


================================================
FILE: scripts/clf_linear.py
================================================
import argparse
import os

from loguru import logger
import torch
from torch import nn
from torch.cuda import amp
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter

from virtex.config import Config
from virtex.factories import (
    DownstreamDatasetFactory,
    PretrainingModelFactory,
    OptimizerFactory,
    LRSchedulerFactory,
)
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser, common_setup, cycle
import virtex.utils.distributed as dist
from virtex.utils.metrics import TopkAccuracy
from virtex.utils.timer import Timer


# fmt: off
parser = common_parser(
    description="""Do image classification with linear models and frozen
    feature extractor, or fine-tune the feature extractor end-to-end."""
)
group = parser.add_argument_group("Downstream config arguments.")
group.add_argument(
    "--down-config", metavar="FILE", help="Path to a downstream config file."
)
group.add_argument(
    "--down-config-override", nargs="*", default=[],
    help="A list of key-value pairs to modify downstream config params.",
)

group = parser.add_argument_group("Checkpointing and Logging")
group.add_argument(
    "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
    default="virtex", help="""How to initialize weights:
        1. 'random' initializes all weights randomly
        2. 'imagenet' initializes backbone weights from torchvision model zoo
        3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
            - with 'torchvision', state dict would be from PyTorch's training
              script.
            - with 'virtex' it should be for our full pretrained model."""
)
group.add_argument(
    "--log-every", type=int, default=50,
    help="""Log training curves to tensorboard every this many iterations.
    Only the master process logs, using its own (non-averaged) loss.""",
)
group.add_argument(
    "--checkpoint-path",
    help="""Path to load checkpoint and run downstream task evaluation. The
    name of the checkpoint file must be `model_*.pth`, where * is the
    iteration number at which the checkpoint was serialized."""
)
group.add_argument(
    "--checkpoint-every", type=int, default=5000,
    help="""Serialize model to a checkpoint every this many iterations. One
    epoch is 5005 iterations for ImageNet and 1710 iterations for
    iNaturalist.""",
)
# fmt: on


def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device as set for current distributed process.
        # Check `launch` function in `virtex.utils.distributed` module.
        device = torch.cuda.current_device()

    # Create a downstream config object (this will be immutable) and perform
    # common setup such as logging and setting up serialization directory.
    _DOWNC = Config(_A.down_config, _A.down_config_override)
    common_setup(_DOWNC, _A, job_type="downstream")

    # Create a (pretraining) config object and back it up in the serialization directory.
    _C = Config(_A.config, _A.config_override)
    _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))

    # Get dataset name for tensorboard logging.
    DATASET = _DOWNC.DATA.ROOT.split("/")[-1]

    # Set number of output classes according to dataset:
    NUM_CLASSES_MAPPING = {"imagenet": 1000, "inaturalist": 8142}
    NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET]

    # -------------------------------------------------------------------------
    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
    # -------------------------------------------------------------------------
    train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="train")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            train_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=True,
        ),
        drop_last=False,
        pin_memory=True,
        collate_fn=train_dataset.collate_fn,
    )
    val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="val")
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            val_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=False,
        ),
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )
    # Initialize model using pretraining config.
    pretrained_model = PretrainingModelFactory.from_config(_C)

    # Load weights according to the init method, do nothing for `random`, and
    # `imagenet` is already taken care of.
    if _A.weight_init == "virtex":
        CheckpointManager(model=pretrained_model).load(_A.checkpoint_path)
    elif _A.weight_init == "torchvision":
        # Keep strict=False because this state dict may have weights for
        # last fc layer.
        pretrained_model.visual.cnn.load_state_dict(
            torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
            strict=False,
        )

    # Pull out the CNN (torchvision-like) from our pretrained model and add
    # back the FC layer; it exists in torchvision models and is set to
    # `nn.Identity()` during pretraining.
    model = pretrained_model.visual.cnn  # type: ignore
    model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device)
    model = model.to(device)

    # Re-initialize the FC layer.
    torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01)
    torch.nn.init.constant_(model.fc.bias.data, 0.0)

    # Freeze all layers except FC as per config param.
    if _DOWNC.MODEL.VISUAL.FROZEN:
        # Set model to eval mode to prevent BatchNorm from updating its running
        # mean and variance. Since only the linear layer is trained, eval mode
        # does not otherwise affect training.
        model.eval()

        for name, param in model.named_parameters():
            if "fc" not in name:
                param.requires_grad = False

    # Cross entropy loss and accuracy meter.
    criterion = nn.CrossEntropyLoss()
    top1 = TopkAccuracy(k=1)

    optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters())
    scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer)
    del pretrained_model

    # -------------------------------------------------------------------------
    #  BEFORE TRAINING STARTS
    # -------------------------------------------------------------------------

    # Create a gradient scaler for automatic mixed precision.
    scaler = amp.GradScaler(enabled=_DOWNC.AMP)

    # Create an iterator from dataloader to sample batches perpetually.
    train_dataloader_iter = cycle(train_dataloader, device)

    if dist.get_world_size() > 1:
        dist.synchronize()
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )

    if dist.is_master_process():
        checkpoint_manager = CheckpointManager(
            _A.serialization_dir,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)

    # Keep track of time per iteration and ETA.
    timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS)

    # -------------------------------------------------------------------------
    #   TRAINING LOOP
    # -------------------------------------------------------------------------
    for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1):
        timer.tic()
        optimizer.zero_grad()
        batch = next(train_dataloader_iter)

        with amp.autocast(enabled=_DOWNC.AMP):
            logits = model(batch["image"])
            loss = criterion(logits, batch["label"])

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()
        timer.toc()

        if iteration % _A.log_every == 0 and dist.is_master_process():
            logger.info(
                f"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB"
            )
            tensorboard_writer.add_scalar(f"{DATASET}/train_loss", loss, iteration)
            tensorboard_writer.add_scalar(
                f"{DATASET}/learning_rate",
                optimizer.param_groups[0]["lr"],
                iteration,
            )

        # ---------------------------------------------------------------------
        #   VALIDATION
        # ---------------------------------------------------------------------
        if iteration % _A.checkpoint_every == 0:
            torch.set_grad_enabled(False)
            model.eval()

            total_val_loss = torch.tensor(0.0).to(device)

            for val_iteration, batch in enumerate(val_dataloader, start=1):
                for key in batch:
                    batch[key] = batch[key].to(device)

                logits = model(batch["image"])
                loss = criterion(logits, batch["label"])
                _ = top1(logits, batch["label"])
                total_val_loss += loss

            # Divide each loss component by number of val batches per GPU.
            total_val_loss = total_val_loss / val_iteration
            dist.average_across_processes(total_val_loss)

            # Get accumulated Top-1 accuracy for logging across GPUs.
            acc = top1.get_result()
            top1.reset()
            dist.average_across_processes(acc)

            torch.set_grad_enabled(True)

            # Set model back to train mode only when fine-tuning end-to-end.
            if not _DOWNC.MODEL.VISUAL.FROZEN:
                model.train()

            # Save recent checkpoint and best checkpoint based on accuracy.
            if dist.is_master_process():
                checkpoint_manager.step(iteration)

                logger.info(f"Iter: {iteration} | Top-1 accuracy: {acc}")
                tensorboard_writer.add_scalar(
                    f"{DATASET}/val_loss", total_val_loss, iteration
                )
                # This name scoping will result in Tensorboard displaying all
                # metrics (VOC07, caption, etc.) together.
                tensorboard_writer.add_scalars(
                    f"metrics/{DATASET}", {"top1": acc}, iteration
                )

        # All processes will wait till master process is done logging.
        dist.synchronize()


if __name__ == "__main__":
    _A = parser.parse_args()

    # Add an arg in config override if `--weight-init` is imagenet.
    if _A.weight_init == "imagenet":
        _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])

    if _A.num_gpus_per_machine == 0:
        main(_A)
    else:
        # This will launch `main` and set appropriate CUDA device (GPU ID) as
        # per process (accessed in the beginning of `main`).
        dist.launch(
            main,
            num_machines=_A.num_machines,
            num_gpus_per_machine=_A.num_gpus_per_machine,
            machine_rank=_A.machine_rank,
            dist_url=_A.dist_url,
            args=(_A,),
        )


================================================
FILE: scripts/clf_voc07.py
================================================
import argparse
import multiprocessing as mp
import os
from typing import Any, List

import numpy as np
import torch
from loguru import logger
from sklearn.svm import LinearSVC
from sklearn.metrics import average_precision_score
from sklearn.model_selection import cross_val_score
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from virtex.config import Config
from virtex.factories import PretrainingModelFactory, DownstreamDatasetFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser, common_setup


parser = common_parser(
    description="Train SVMs for VOC2007 classification on a pretrained model."
)
group = parser.add_argument_group("Downstream config arguments.")
group.add_argument(
    "--down-config", metavar="FILE", help="Path to a downstream config file."
)
group.add_argument(
    "--down-config-override",
    nargs="*",
    default=[],
    help="A list of key-value pairs to modify downstream config params.",
)

# fmt: off
group = parser.add_argument_group("Checkpointing")
group.add_argument(
    "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
    default="virtex", help="""How to initialize weights:
        1. 'random' initializes all weights randomly
        2. 'imagenet' initializes backbone weights from torchvision model zoo
        3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
            - with 'torchvision', state dict would be from PyTorch's training
              script.
            - with 'virtex' it should be for our full pretrained model."""
)
group.add_argument(
    "--checkpoint-path",
    help="Path to load checkpoint and run downstream task evaluation."
)
# fmt: on


def train_test_single_svm(args):

    feats_train, tgts_train, feats_test, tgts_test, cls_name = args
    SVM_COSTS = [0.01, 0.1, 1.0, 10.0]

    cls_labels = np.copy(tgts_train)
    # Labels in the original VOC/COCO target files: 1 = present, 0 = not
    # present. Map 0 to -1 so that the SVM training targets are in {-1, 1}.
    cls_labels[np.where(cls_labels == 0)] = -1

    # See which cost maximizes the AP for this class.
    best_crossval_ap: float = 0.0
    best_crossval_clf = None
    best_cost: float = 0.0

    # fmt: off
    for cost in SVM_COSTS:
        clf = LinearSVC(
            C=cost, class_weight={1: 2, -1: 1}, penalty="l2",
            loss="squared_hinge", max_iter=2000,
        )
        ap_scores = cross_val_score(
            clf, feats_train, cls_labels, cv=3, scoring="average_precision",
        )
        clf.fit(feats_train, cls_labels)

        # Keep track of best SVM (based on cost) for each class.
        if ap_scores.mean() > best_crossval_ap:
            best_crossval_ap = ap_scores.mean()
            best_crossval_clf = clf
            best_cost = cost

    logger.info(f"Best SVM {cls_name}: cost {best_cost}, AP {best_crossval_ap * 100}")
    # fmt: on

    # -------------------------------------------------------------------------
    #   TEST THE TRAINED SVM (PER CLASS)
    # -------------------------------------------------------------------------
    predictions = best_crossval_clf.decision_function(feats_test)
    evaluate_data_inds = tgts_test != -1
    eval_preds = predictions[evaluate_data_inds]

    cls_labels = np.copy(tgts_test)
    eval_cls_labels = cls_labels[evaluate_data_inds]
    eval_cls_labels[np.where(eval_cls_labels == 0)] = -1

    # Binarize class labels to make AP targets.
    targets = eval_cls_labels > 0
    return average_precision_score(targets, eval_preds)


def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    # Create a downstream config object (this will be immutable) and perform
    # common setup such as logging and setting up serialization directory.
    _DOWNC = Config(_A.down_config, _A.down_config_override)
    common_setup(_DOWNC, _A, job_type="downstream")

    # Create a (pretraining) config object and backup in serialization directory.
    _C = Config(_A.config, _A.config_override)
    _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))

    # -------------------------------------------------------------------------
    #   INSTANTIATE DATALOADER, MODEL, AND FEATURE EXTRACTOR
    # -------------------------------------------------------------------------

    train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="trainval")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE,
        num_workers=_A.cpu_workers,
        pin_memory=True,
    )
    test_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="test")
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE,
        num_workers=_A.cpu_workers,
        pin_memory=True,
    )
    NUM_CLASSES = len(train_dataset.class_names)

    # Initialize from a checkpoint, but only keep the visual module.
    model = PretrainingModelFactory.from_config(_C)

    # Load weights according to the init method, do nothing for `random`, and
    # `imagenet` is already taken care of. ``ITERATION`` keeps a dummy value
    # unless it is read from a VirTex checkpoint.
    ITERATION = 0
    if _A.weight_init == "virtex":
        ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
    elif _A.weight_init == "torchvision":
        # Keep strict=False because this state dict may have weights for
        # last fc layer.
        model.visual.cnn.load_state_dict(
            torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
            strict=False,
        )

    # Transfer model to the device and set to eval mode. This is a torchvision
    # model and it returns spatial features, e.g. ``(batch_size, 2048, 7, 7)``
    # for a ResNet-50 with 224x224 inputs.
    model = model.visual.cnn.to(device).eval()

    # -------------------------------------------------------------------------
    #   EXTRACT FEATURES FOR TRAINING SVMs
    # -------------------------------------------------------------------------

    features_train: List[torch.Tensor] = []
    targets_train: List[torch.Tensor] = []

    features_test: List[torch.Tensor] = []
    targets_test: List[torch.Tensor] = []

    # VOC07 is small, extract all features and keep them in memory.
    with torch.no_grad():
        for batch in tqdm(train_dataloader, desc="Extracting train features:"):
            features = model(batch["image"].to(device))

            # Global average pool features. Assume the tensor is in NCHW format.
            if len(features.size()) > 2:
                # shape: (batch_size, visual_feature_size)
                features = features.mean(dim=(2, 3))

            # L2-normalize the global average pooled features.
            features = F.normalize(features, dim=-1)

            features_train.append(features.cpu())
            targets_train.append(batch["label"])

        # Similarly extract test features.
        for batch in tqdm(test_dataloader, desc="Extracting test features:"):
            features = model(batch["image"].to(device))

            if len(features.size()) > 2:
                features = features.mean(dim=(2, 3))

            features = F.normalize(features, dim=-1)

            features_test.append(features.cpu())
            targets_test.append(batch["label"])

    # Convert batches of features/targets to one large numpy array
    features_train = torch.cat(features_train, dim=0).numpy()
    targets_train = torch.cat(targets_train, dim=0).numpy().astype(np.int32)

    features_test = torch.cat(features_test, dim=0).numpy()
    targets_test = torch.cat(targets_test, dim=0).numpy().astype(np.int32)

    # -------------------------------------------------------------------------
    #   TRAIN AND TEST SVMs WITH EXTRACTED FEATURES
    # -------------------------------------------------------------------------

    input_args: List[Any] = []

    # Iterate over all VOC07 classes and train one-vs-all linear SVMs.
    for cls_idx in range(NUM_CLASSES):
        # fmt: off
        input_args.append((
            features_train, targets_train[:, cls_idx],
            features_test, targets_test[:, cls_idx],
            train_dataset.class_names[cls_idx],
        ))
        # fmt: on

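    # Train and test one SVM per class in parallel worker processes. The
    # context manager shuts the pool down once all results are collected.
    with mp.Pool(processes=_A.cpu_workers) as pool:
        pool_output = pool.map(train_test_single_svm, input_args)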

    # -------------------------------------------------------------------------
    #   TENSORBOARD LOGGING (RELEVANT MAINLY FOR weight_init=virtex)
    # -------------------------------------------------------------------------

    # Tensorboard writer for logging mAP scores. This is useful especially
    # when weight_init=virtex (the checkpoint may come from a training job).
    tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)

    # Test set mAP: mean of the per-class test set APs of final-layer features.
    test_map = torch.tensor(pool_output).mean()
    logger.info(f"Iteration: {ITERATION}, mAP: {test_map * 100}")
    tensorboard_writer.add_scalars(
        "metrics/voc07_clf", {"voc07_mAP": test_map * 100}, ITERATION
    )
    tensorboard_writer.close()

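
# Illustrative sketch (hypothetical helper, not called by ``main`` above): the
# global-average-pool + L2-normalize step from feature extraction, written in
# plain Python for one ``(C, H, W)`` feature map given as nested lists. It
# mirrors ``F.normalize`` with its defaults (``p=2``, ``eps=1e-12``), which
# divides by ``max(||x||_2, eps)`` so an all-zero vector stays zero.
import math
from typing import List


def _gap_and_l2_normalize_sketch(
    feature_map: List[List[List[float]]], eps: float = 1e-12
) -> List[float]:
    # Average each channel over all of its spatial (H, W) positions.
    pooled = [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in feature_map
    ]
    # Divide by the L2 norm, clamped from below by ``eps``.
    norm = max(math.sqrt(sum(v * v for v in pooled)), eps)
    return [v / norm for v in pooled]
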

if __name__ == "__main__":
    _A = parser.parse_args()

    if _A.num_gpus_per_machine > 1:
        raise ValueError("Using multiple GPUs is not supported for this script.")

    # Add an arg in config override if `--weight-init` is imagenet.
    if _A.weight_init == "imagenet":
        _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])

    # No distributed training here, just a single process.
    main(_A)


================================================
FILE: scripts/eval_captioning.py
================================================
import argparse
import json
import os
from typing import Any, Dict, List

from loguru import logger
import torch
from torch.utils.data import DataLoader

from virtex.config import Config
from virtex.data import ImageDirectoryDataset
from virtex.factories import TokenizerFactory, PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
from virtex.utils.metrics import CocoCaptionsEvaluator


# fmt: off
parser = common_parser(
    description="""Run image captioning inference on a pretrained model, and/or
    evaluate pretrained model on COCO Captions val2017 split."""
)
parser.add_argument(
    "--images", "--data-root", dest="data_root", default=None,
    help="""Path to a directory containing image files to generate captions for.
    Default: COCO val2017 image directory as expected relative to project root."""
)
parser.add_argument(
    "--checkpoint-path", required=True,
    help="Path to load checkpoint and run captioning evaluation."
)
parser.add_argument(
    "--output", default=None,
    help="Path to save predictions as a JSON file."
)
parser.add_argument(
    "--calc-metrics", action="store_true",
    help="""Calculate CIDEr and SPICE metrics using ground truth COCO Captions.
    This flag should not be set when running inference on arbitrary images."""
)
# fmt: on


def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    _C = Config(_A.config, _A.config_override)

    tokenizer = TokenizerFactory.from_config(_C)

    if _A.data_root is None:
        _A.data_root = os.path.join(_C.DATA.ROOT, "val2017")

    val_dataloader = DataLoader(
        ImageDirectoryDataset(_A.data_root),
        batch_size=_C.OPTIM.BATCH_SIZE,
        num_workers=_A.cpu_workers,
        pin_memory=True,
    )
    # Initialize model from a checkpoint.
    model = PretrainingModelFactory.from_config(_C).to(device)
    ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
    model.eval()

    # Make a list of predictions to evaluate.
    predictions: List[Dict[str, Any]] = []

    for val_iteration, val_batch in enumerate(val_dataloader, start=1):

        val_batch["image"] = val_batch["image"].to(device)
        with torch.no_grad():
            output_dict = model(val_batch)

        # Make a dictionary of predictions in COCO format.
        for image_id, caption in zip(
            val_batch["image_id"], output_dict["predictions"]
        ):
            predictions.append(
                {
                    # Convert image id to int if possible (mainly for COCO eval).
                    "image_id": int(image_id) if image_id.isdigit() else image_id,
                    "caption": tokenizer.decode(caption.tolist()),
                }
            )

    logger.info("Displaying first 25 caption predictions:")
    for pred in predictions[:25]:
        logger.info(f"{pred['image_id']} :: {pred['caption']}")

    # Save predictions as a JSON file if specified.
    if _A.output is not None:
        # `dirname` is empty when `--output` is a bare filename.
        output_dir = os.path.dirname(_A.output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(_A.output, "w") as f:
            json.dump(predictions, f)
        logger.info(f"Saved predictions to {_A.output}")

    # Calculate CIDEr and SPICE metrics using ground truth COCO Captions. This
    # should be skipped when running inference on arbitrary images.
    if _A.calc_metrics:
        # Assume ground truth COCO val2017 annotations exist.
        gt = os.path.join(_C.DATA.ROOT, "annotations", "captions_val2017.json")

        metrics = CocoCaptionsEvaluator(gt).evaluate(predictions)
        logger.info(f"Iter: {ITERATION} | Metrics: {metrics}")


if __name__ == "__main__":
    _A = parser.parse_args()
    if _A.num_gpus_per_machine > 1:
        raise ValueError("Using multiple GPUs is not supported for this script.")

    # No distributed training here, just a single process.
    main(_A)


================================================
FILE: scripts/eval_detectron2.py
================================================
"""
Finetune a pre-trained model on a downstream task, one of those available in
Detectron2.
Supported downstream tasks:
  - LVIS Instance Segmentation
  - COCO Instance Segmentation
  - Pascal VOC 2007+12 Object Detection

Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py
Thanks to the developers of Detectron2!
"""
import argparse
import os
import re

import torch
from torch.utils.tensorboard import SummaryWriter

import detectron2 as d2
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultTrainer, default_setup
from detectron2.evaluation import (
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    COCOEvaluator,
)
from detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads

from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser
import virtex.utils.distributed as dist

# fmt: off
parser = common_parser(
    description="Train object detectors from pretrained visual backbone."
)
parser.add_argument(
    "--d2-config", required=True,
    help="Path to a detectron2 config for downstream task finetuning."
)
parser.add_argument(
    "--d2-config-override", nargs="*", default=[],
    help="""Key-value pairs from Detectron2 config to override from file.
    Some keys will be ignored because they are set from other args or from
    the pretraining config: [DATALOADER.NUM_WORKERS, SOLVER.CHECKPOINT_PERIOD,
    OUTPUT_DIR, MODEL.RESNETS.DEPTH]""",
)

group = parser.add_argument_group("Checkpointing and Logging")
group.add_argument(
    "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"],
    default="virtex", help="""How to initialize weights:
        1. 'random' initializes all weights randomly
        2. 'imagenet' initializes backbone weights from torchvision model zoo
        3. {'torchvision', 'virtex'} load state dict from --checkpoint-path
            - with 'torchvision', state dict would be from PyTorch's training
              script.
            - with 'virtex' it should be for our full pretrained model."""
)
group.add_argument(
    "--checkpoint-path",
    help="Path to load checkpoint and run downstream task evaluation."
)
group.add_argument(
    "--resume", action="store_true", help="""Specify this flag when resuming
    training from a checkpoint saved by Detectron2."""
)
group.add_argument(
    "--eval-only", action="store_true",
    help="Skip training and evaluate checkpoint provided at --checkpoint-path.",
)
group.add_argument(
    "--checkpoint-every", type=int, default=5000,
    help="Serialize model to a checkpoint after every this many iterations.",
)
# fmt: on


@ROI_HEADS_REGISTRY.register()
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
    r"""
    ROI head with ``res5`` stage followed by a BN layer. Used with Faster R-CNN
    C4/DC5 backbones for VOC detection.
    """

    def _build_res5_block(self, cfg):
        seq, out_channels = super()._build_res5_block(cfg)
        norm = d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, out_channels)
        seq.add_module("norm", norm)
        return seq, out_channels


def build_detectron2_config(_C: Config, _A: argparse.Namespace):
    r"""Build detectron2 config based on our pre-training config and args."""
    _D2C = d2.config.get_cfg()

    # Override some default values based on our config file.
    _D2C.merge_from_file(_A.d2_config)
    _D2C.merge_from_list(_A.d2_config_override)

    # Set some config parameters from args.
    _D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers
    _D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every
    _D2C.OUTPUT_DIR = _A.serialization_dir

    # Set ResNet depth to override in Detectron2's config, parsed from the
    # visual backbone name (0 if the name matches neither pattern).
    if "torchvision" in _C.MODEL.VISUAL.NAME:
        depth = re.search(r"resnet(\d+)", _C.MODEL.VISUAL.NAME).group(1)
    elif "detectron2" in _C.MODEL.VISUAL.NAME:
        depth = re.search(r"_R_(\d+)", _C.MODEL.VISUAL.NAME).group(1)
    else:
        depth = 0
    _D2C.MODEL.RESNETS.DEPTH = int(depth)
    return _D2C


class DownstreamTrainer(DefaultTrainer):
    r"""
    Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks.

    Arguments:
        cfg (detectron2.config.CfgNode): Detectron2 config object.
        weights (Union[str, Dict]): Weights to load in the initialized model.
            If ``str``, then we assume path to a checkpoint, or if a ``dict``,
            we assume a state dict. This will be an ``str`` only if training
            is resumed from a Detectron2 checkpoint.
    """

    def __init__(self, cfg, weights):

        super().__init__(cfg)

        # Load pre-trained weights before wrapping to DDP because `ApexDDP` has
        # some weird issue with `DetectionCheckpointer`.
        # fmt: off
        if isinstance(weights, str):
            # ``str`` weights mean training is resumed from a Detectron2
            # checkpoint.
            self.start_iter = (
                DetectionCheckpointer(
                    self._trainer.model,
                    optimizer=self._trainer.optimizer,
                    scheduler=self.scheduler
                ).resume_or_load(weights, resume=True).get("iteration", -1) + 1
            )
        elif isinstance(weights, dict):
            # A state dict holds backbone weights prepared in `main` (VirTex,
            # torchvision, random or ImageNet init).
            DetectionCheckpointer(self._trainer.model)._load_model(weights)
        # fmt: on

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        elif evaluator_type == "coco":
            return COCOEvaluator(dataset_name, cfg, True, output_folder)
        elif evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        # DefaultTrainer.test skips datasets whose evaluator raises this.
        raise NotImplementedError(
            f"No evaluator for dataset '{dataset_name}' of type '{evaluator_type}'."
        )

    def test(self, cfg=None, model=None, evaluators=None):
        r"""Evaluate the model and log results to stdout and tensorboard."""
        cfg = cfg or self.cfg
        model = model or self.model

        tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
        results = super().test(cfg, model, evaluators)
        flat_results = d2.evaluation.testing.flatten_results_dict(results)
        for k, v in flat_results.items():
            tensorboard_writer.add_scalar(k, v, self.start_iter)
        tensorboard_writer.close()
        return results


def main(_A: argparse.Namespace):

    # Local process group is needed for detectron2.
    pg = list(range(dist.get_world_size()))
    d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(pg)

    # Create a config object (this will be immutable) and perform common setup
    # such as logging and setting up serialization directory.
    if _A.weight_init == "imagenet":
        _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True])
    _C = Config(_A.config, _A.config_override)

    # We use `default_setup` from detectron2 to do some common setup, such as
    # logging, setting up serialization etc. For more info, look into source.
    _D2C = build_detectron2_config(_C, _A)
    default_setup(_D2C, _A)

    # Prepare weights to pass in instantiation call of trainer.
    if _A.weight_init in {"virtex", "torchvision"}:
        if _A.resume:
            # If resuming training, let detectron2 load weights by providing path.
            model = None
            weights = _A.checkpoint_path
        else:
            # Load backbone weights from a VirTex or torchvision checkpoint.
            model = PretrainingModelFactory.from_config(_C)
            if _A.weight_init == "virtex":
                CheckpointManager(model=model).load(_A.checkpoint_path)
            else:
                model.visual.cnn.load_state_dict(
                    torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
                    strict=False,
                )
            weights = model.visual.detectron2_backbone_state_dict()
    else:
        # If random or imagenet init, just load weights after initializing model.
        model = PretrainingModelFactory.from_config(_C)
        weights = model.visual.detectron2_backbone_state_dict()

    # Back up pretrain config and model checkpoint (if provided).
    _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))
    if _A.weight_init == "virtex" and not _A.resume:
        torch.save(
            model.state_dict(),
            os.path.join(_A.serialization_dir, "pretrain_model.pth"),
        )

    del model
    trainer = DownstreamTrainer(_D2C, weights)
    if _A.eval_only:
        trainer.test()
    else:
        trainer.train()


if __name__ == "__main__":
    _A = parser.parse_args()

    # This will launch `main` and set appropriate CUDA device (GPU ID) as
    # per process (accessed in the beginning of `main`).
    dist.launch(
        main,
        num_machines=_A.num_machines,
        num_gpus_per_machine=_A.num_gpus_per_machine,
        machine_rank=_A.machine_rank,
        dist_url=_A.dist_url,
        args=(_A, ),
    )


================================================
FILE: scripts/pretrain_virtex.py
================================================
import argparse
from collections import Counter
from typing import Any

from loguru import logger
import torch
from torch import nn
from torch.cuda import amp
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter

# fmt: off
from virtex.config import Config
from virtex.factories import (
    PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory,
    LRSchedulerFactory,
)
from virtex.utils.checkpointing import CheckpointManager
from virtex.utils.common import common_parser, common_setup, cycle
import virtex.utils.distributed as dist
from virtex.utils.timer import Timer


parser = common_parser(
    description="Train a VirTex model (CNN + Transformer) on COCO Captions."
)
group = parser.add_argument_group("Checkpointing and Logging")
group.add_argument(
    "--resume-from", default=None,
    help="Path to a checkpoint to resume training from (if provided)."
)
group.add_argument(
    "--checkpoint-every", type=int, default=2000,
    help="Serialize model to a checkpoint after every this many iterations.",
)
group.add_argument(
    "--log-every", type=int, default=20,
    help="""Log training curves to tensorboard after every this many iterations.
    Only the master process writes to tensorboard.""",
)
# fmt: on

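
# Illustrative sketch (hypothetical helper, not called by ``main`` below): how
# ``DistributedSampler(dataset, shuffle=False)`` with its default
# ``drop_last=False`` splits dataset indices across processes. Indices are
# padded by wrapping around to the start, so every rank gets the same number
# of samples, then rank ``r`` takes every ``world_size``-th index starting at
# ``r``. With ``shuffle=True`` the indices are first permuted with a seed that
# is shared by all processes. ``main`` pairs this with a per-process batch
# size of ``_C.OPTIM.BATCH_SIZE // dist.get_world_size()``.
import math
from typing import List


def _shard_indices_sketch(
    num_samples: int, rank: int, world_size: int
) -> List[int]:
    indices = list(range(num_samples))
    total_size = math.ceil(num_samples / world_size) * world_size
    padding = total_size - num_samples
    if padding > 0:
        # Repeat indices from the beginning until the total is divisible.
        indices += (indices * math.ceil(padding / num_samples))[:padding]
    # Strided split: rank r gets positions r, r + W, r + 2W, ...
    return indices[rank:total_size:world_size]
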

def main(_A: argparse.Namespace):

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device: Any = torch.device("cpu")
    else:
        # Get the current device as set for current distributed process.
        # Check `launch` function in `virtex.utils.distributed` module.
        device = torch.cuda.current_device()

    # Create a config object (this will be immutable) and perform common setup
    # such as logging and setting up serialization directory.
    _C = Config(_A.config, _A.config_override)
    common_setup(_C, _A)

    # -------------------------------------------------------------------------
    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
    # -------------------------------------------------------------------------
    train_dataset = PretrainingDatasetFactory.from_config(_C, split="train")
    val_dataset = PretrainingDatasetFactory.from_config(_C, split="val")

    # Make `DistributedSampler`s to shard datasets across GPU processes.
    # Skip this if training on CPUs.
    train_sampler = (
        DistributedSampler(train_dataset, shuffle=True)  # type: ignore
        if _A.num_gpus_per_machine > 0
        else None
    )
    val_sampler = (
        DistributedSampler(val_dataset, shuffle=False)  # type: ignore
        if _A.num_gpus_per_machine > 0
        else None
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
        sampler=train_sampler,
        shuffle=train_sampler is None,
        num_workers=_A.cpu_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=train_dataset.collate_fn,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
        sampler=val_sampler,
        shuffle=False,
        num_workers=_A.cpu_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )

    model = PretrainingModelFactory.from_config(_C).to(device)
    optimizer = OptimizerFactory.from_config(_C, model.named_parameters())
    scheduler = LRSchedulerFactory.from_config(_C, optimizer)

    # -------------------------------------------------------------------------
    #   BEFORE TRAINING STARTS
    # -------------------------------------------------------------------------

    # Create a gradient scaler for automatic mixed precision.
    scaler = amp.GradScaler(enabled=_C.AMP)

    # Load checkpoint to resume training if specified.
    if _A.resume_from is not None:
        start_iteration = CheckpointManager(
            model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler,
        ).load(_A.resume_from)
    else:
        start_iteration = 0

    # Create an iterator from dataloader to sample batches perpetually.
    train_dataloader_iter = cycle(train_dataloader, device, start_iteration)

    # Wrap model in DDP if using more than one process.
    if dist.get_world_size() > 1:
        dist.synchronize()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device])

    # Keep track of time per iteration and ETA.
    timer = Timer(
        start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS
    )
    # Create tensorboard writer and checkpoint manager (only in master process).
    if dist.is_master_process():
        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)
        tensorboard_writer.add_text("config", f"```\n{_C}\n```")

        checkpoint_manager = CheckpointManager(
            _A.serialization_dir,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
        )

    # -------------------------------------------------------------------------
    #   TRAINING LOOP
    # -------------------------------------------------------------------------
    for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1):
        timer.tic()
        optimizer.zero_grad()
        batch = next(train_dataloader_iter)

        with amp.autocast(enabled=_C.AMP):
            output_dict = model(batch)
            loss = output_dict["loss"]

        scaler.scale(loss).backward()

        # First clip norm of gradients, and then perform optimizer step.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM)
        scaler.step(optimizer)

        scaler.update()
        scheduler.step()
        timer.toc()

        # ---------------------------------------------------------------------
        #   LOGGING
        # ---------------------------------------------------------------------
        if iteration % _A.log_every == 0:
            logger.info(
                f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]"
            )
            if dist.is_master_process():
                tensorboard_writer.add_scalars(
                    "learning_rate",
                    {
                        "visual": optimizer.param_groups[0]["lr"],
                        "common": optimizer.param_groups[-1]["lr"],
                    },
                    iteration,
                )
                tensorboard_writer.add_scalars(
                    "train", output_dict["loss_components"], iteration
                )

        # ---------------------------------------------------------------------
        #   CHECKPOINTING AND VALIDATION
        # ---------------------------------------------------------------------
        if iteration % _A.checkpoint_every == 0:
            if dist.is_master_process():
                checkpoint_manager.step(iteration)

            # All processes will wait till master process is done serializing.
            dist.synchronize()

            torch.set_grad_enabled(False)
            model.eval()

            # Accumulate different val loss components according to the type of
            # pretraining model.
            val_loss_counter: Counter = Counter()

            for val_iteration, val_batch in enumerate(val_dataloader, start=1):
                for key in val_batch:
                    val_batch[key] = val_batch[key].to(device)
                output_dict = model(val_batch)

                val_loss_counter.update(output_dict["loss_components"])

            # Divide each loss component by number of val batches per GPU.
            val_loss_dict = {
                k: v / val_iteration for k, v in dict(val_loss_counter).items()
            }
            dist.average_across_processes(val_loss_dict)
            torch.set_grad_enabled(True)
            model.train()

            logger.info(f"Iteration: {iteration} [Val loss: {val_loss_dict}]")
            if dist.is_master_process():
                tensorboard_writer.add_scalars("val", val_loss_dict, iteration)

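
# Illustrative sketch (hypothetical helper, not called by ``main`` above): the
# per-process validation loss averaging in ``main``. ``Counter.update`` with a
# mapping *adds* values key by key, so accumulating each batch's
# ``loss_components`` and dividing by the number of batches gives the mean of
# every loss component. This covers one process only; ``main`` then calls
# ``dist.average_across_processes`` to average the result across processes.
from collections import Counter
from typing import Dict, List


def _average_loss_components_sketch(
    per_batch_components: List[Dict[str, float]]
) -> Dict[str, float]:
    totals: Counter = Counter()
    for components in per_batch_components:
        totals.update(components)
    # An empty input yields an empty dict (no division happens).
    return {k: v / len(per_batch_components) for k, v in totals.items()}
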

if __name__ == "__main__":
    _A = parser.parse_args()

    if _A.num_gpus_per_machine == 0:
        main(_A)
    else:
        # This will launch `main` and set appropriate CUDA device (GPU ID) as
        # per process (accessed in the beginning of `main`).
        dist.launch(
            main,
            num_machines=_A.num_machines,
            num_gpus_per_machine=_A.num_gpus_per_machine,
            machine_rank=_A.machine_rank,
            dist_url=_A.dist_url,
            args=(_A, ),
        )


================================================
FILE: setup.py
================================================
#!/usr/bin/env python
import glob
import os
from setuptools import setup
import shutil
from typing import List


def get_model_zoo_configs() -> List[str]:
    """
    Return a list of configs to include in the package for the model zoo.
    These configs are symlinked (or copied, if symlinking fails) into
    virtex/model_zoo/configs.
    """

    # Use absolute paths while symlinking.
    source_configs_dir = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "configs"
    )
    destination = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "virtex", "model_zoo", "configs"
    )
    # Remove stale symlink/directory from a previous build.
    if os.path.exists(source_configs_dir):
        if os.path.islink(destination):
            os.unlink(destination)
        elif os.path.isdir(destination):
            shutil.rmtree(destination)

    # Symlink the config directory inside package to have a cleaner pip install.
    if not os.path.exists(destination):
        try:
            os.symlink(source_configs_dir, destination)
        except OSError:
            # Fall back to copying if symlink fails: ex. on Windows.
            shutil.copytree(source_configs_dir, destination)

    config_paths = glob.glob("configs/**/*.yaml", recursive=True)
    return config_paths


setup(
    name="virtex",
    version="1.4.0",
    author="Karan Desai and Justin Johnson",
    description="VirTex: Learning Visual Representations with Textual Annotations",
    package_data={"virtex.model_zoo": get_model_zoo_configs()},
    python_requires=">=3.8",
    license="MIT",
    zip_safe=True,
)


================================================
FILE: virtex/__init__.py
================================================


================================================
FILE: virtex/config.py
================================================
from typing import Any, List, Optional

from fvcore.common.config import CfgNode as CN

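
# Illustrative sketch (hypothetical helper, not used by ``Config``): what an
# ``override_list`` such as ``["OPTIM.BATCH_SIZE", 1024]`` means. It is a flat
# list alternating dotted key names and values, applied after the YAML file,
# so its values win. This plain-dict version only rejects unknown keys; the
# ``merge_from_list`` method of yacs/fvcore ``CfgNode`` also decodes string
# values and checks that their types match the existing values.
from typing import Any, Dict, List


def _apply_override_list_sketch(
    cfg: Dict[str, Any], override_list: List[Any]
) -> Dict[str, Any]:
    if len(override_list) % 2 != 0:
        raise ValueError("override_list must alternate keys and values.")
    for full_key, value in zip(override_list[0::2], override_list[1::2]):
        *parents, leaf = full_key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # Raises KeyError for an unknown parent key.
        if leaf not in node:
            raise KeyError(f"Non-existent config key: {full_key}")
        # Updates ``cfg`` in place.
        node[leaf] = value
    return cfg
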

class Config:
    r"""
    This class provides package-wide configuration management. It is a
    nested dict-like structure with nested keys accessible as attributes. It
    contains sensible default values, which can be modified by (first) a YAML
    file and (second) a list of attributes and values.

    An instantiated object is immutable: modifying any attribute is illegal.
    You must override required parameter values either through ``config_file``
    or ``override_list`` arguments.

    Args:
        config_file: Path to a YAML file containing config parameters.
        override_list: A flat list of alternating attribute names and values.
            These are applied after overriding from the YAML file.

    Examples:
        Let a YAML file named "config.yaml" specify these parameters to override::

            OPTIM:
              BATCH_SIZE: 512
              LR: 0.01

        >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 1024])
        >>> _C.OPTIM.LR  # default: 0.001
        0.01
        >>> _C.OPTIM.BATCH_SIZE  # default: 256, file: 512
        1024
    """

    def __init__(
        self, config_file: Optional[str] = None, override_list: List[Any] = []
    ):
        _C = CN()

        # Random seed for NumPy and PyTorch, important for reproducibility.
        _C.RANDOM_SEED = 0
        # Train with Automatic Mixed Precision (native PyTorch).
        _C.AMP = True
        # Set CUDNN deterministic flag (torch.backends.cudnn.deterministic).
        # Setting this will ensure exact results on every run at the cost of a
        # small slowdown. Good for debugging.
        _C.CUDNN_DETERMINISTIC = False
        # Set CUDNN benchmark flag (torch.backends.cudnn.benchmark). Enables
        # CUDNN to select the fastest implementation for operations based on
        # the GPU. Faster to train, but results may differ slightly (in
        # decimals) across hardware. Turn off while debugging.
        _C.CUDNN_BENCHMARK = True

        # ---------------------------------------------------------------------
        #   Data paths and parameters related to dataloading.
        # ---------------------------------------------------------------------
        _C.DATA = CN()

        # Path to the dataset root, structured as described in the README.
        # Path is assumed to be relative to project root.
        _C.DATA.ROOT = "datasets/coco"
        # Path to .model file generated by ``sentencepiece``.
        _C.DATA.TOKENIZER_MODEL = "datasets/vocab/coco_10k.model"

        # Handy config params for vocab size and indices of special tokens.
        # While these can be picked up from the tokenizer, having these in
        # the config makes it easy to create a model without instantiating too
        # many tokenizer instances (especially when not needed, e.g. model zoo).
        # These must match the vocabulary of ``TOKENIZER_MODEL`` above.
        _C.DATA.VOCAB_SIZE = 10000
        # Index of out-of-vocabulary (and padding) token.
        _C.DATA.UNK_INDEX = 0
        # Index of the start-of-sentence [SOS] token.
        _C.DATA.SOS_INDEX = 1
        # Index of the end-of-sentence [EOS] token.
        _C.DATA.EOS_INDEX = 2
        # Index of the word masking token. While not used for captioning, having
        # this extra token makes it possible to train an MLM model without
        # re-creating a new vocab mapping.
        _C.DATA.MASK_INDEX = 3

        # Size of the image (square) to crop from original input image.
        _C.DATA.IMAGE_CROP_SIZE = 224
        # Maximum length of input caption (number of tokens).
        # Longer captions will be truncated up to this length.
        _C.DATA.MAX_CAPTION_LENGTH = 30

        # List of image transforms (pre-processing and data augmentation) to be
        # applied sequentially (always or randomly) during training and
        # validation. Refer to ``virtex/factories.py`` for all possible transforms.
        _C.DATA.IMAGE_TRANSFORM_TRAIN = [
            "random_resized_crop",
            "horizontal_flip",
            "color_jitter",
            "normalize",
        ]
        _C.DATA.IMAGE_TRANSFORM_VAL = [
            "smallest_resize",
            "center_crop",
            "normalize",
        ]

        # Hyper-parameters for masked LM pretraining task. These are only used
        # when ``MODEL.NAME`` is "masked_lm".
        _C.DATA.MASKED_LM = CN()
        # Fraction of tokens to choose for masking; this must be less than 1.
        _C.DATA.MASKED_LM.MASK_PROPORTION = 0.15
        # Probability to replace chosen tokens with [MASK] token.
        _C.DATA.MASKED_LM.MASK_PROBABILITY = 0.85
        # Probability to replace chosen tokens with a random token.
        _C.DATA.MASKED_LM.REPLACE_PROBABILITY = 0.10

        # ---------------------------------------------------------------------
        #   Model architecture: visual backbone and textual head.
        # ---------------------------------------------------------------------
        _C.MODEL = CN()

        # Name of model, based on pretraining task.
        # Possible choices: {"token_classification", "multilabel_classification",
        # "captioning", "bicaptioning", "masked_lm", "virtex"}
        _C.MODEL.NAME = "virtex"

        _C.MODEL.VISUAL = CN()
        # Name of visual backbone. Possible choices: {"blind", "torchvision"}
        # Models from torchvision can be specified as shown below.
        _C.MODEL.VISUAL.NAME = "torchvision::resnet50"
        # Number of channels in pooled spatial features of visual backbone.
        _C.MODEL.VISUAL.FEATURE_SIZE = 2048
        # Whether to load ImageNet pretrained weights into visual backbone.
        _C.MODEL.VISUAL.PRETRAINED = False
        # Whether to keep visual backbone frozen and train only textual head.
        _C.MODEL.VISUAL.FROZEN = False

        _C.MODEL.TEXTUAL = CN()
        # Name of textual head. Set to "none" for MODEL.NAME = "*_classification".
        # Possible choices: {"transdec_postnorm", "transdec_prenorm"}.
        # Architectural hyper-parameters are specified as shown below.
        _C.MODEL.TEXTUAL.NAME = "transdec_postnorm::L1_H2048_A32_F8192"
        # L = Number of layers in the transformer.
        # H = Hidden size of the transformer (embeddings, attention features).
        # A = Number of attention heads in the transformer.
        # F = Size of feedforward layers in the transformer.
        # Typically, we have (A = H / 64) and (F = 4 * H).

        # Dropout probability for embeddings and hidden features in textual head.
        _C.MODEL.TEXTUAL.DROPOUT = 0.1

        _C.MODEL.DECODER = CN()
        # What algorithm to use for decoding. Supported values: {"beam_search",
        # "nucleus_sampling"}.
        _C.MODEL.DECODER.NAME = "beam_search"
        # Number of beams to decode (1 = greedy decoding). Ignored when decoding
        # through nucleus sampling.
        _C.MODEL.DECODER.BEAM_SIZE = 5
        # Size of nucleus for sampling predictions. Ignored when decoding through
        # beam search.
        _C.MODEL.DECODER.NUCLEUS_SIZE = 0.9
        # Maximum length of decoded caption. Decoding may end earlier when [EOS]
        # token is sampled.
        _C.MODEL.DECODER.MAX_DECODING_STEPS = _C.DATA.MAX_CAPTION_LENGTH

        # ---------------------------------------------------------------------
        #   Optimization hyper-parameters, default values are for pretraining
        #   our best model on bicaptioning task (COCO Captions).
        # ---------------------------------------------------------------------
        _C.OPTIM = CN()

        # Name of optimizer to use. Supported values: {"sgd", "adamw"}.
        # AdamW uses default (beta1, beta2) values from PyTorch.
        _C.OPTIM.OPTIMIZER_NAME = "sgd"
        # Momentum coefficient for SGD. Ignored for AdamW.
        _C.OPTIM.SGD_MOMENTUM = 0.9
        # Weight decay coefficient for the optimizer.
        _C.OPTIM.WEIGHT_DECAY = 0.0001
        # Regex pattern of params for which there will be no weight decay.
        _C.OPTIM.NO_DECAY = ".*textual.(embedding|transformer).*(norm.*|bias)"
        # Max gradient norm for clipping to avoid exploding gradients.
        _C.OPTIM.CLIP_GRAD_NORM = 10.0

        # Wrap our optimizer with Lookahead (https://arxiv.org/abs/1907.08610).
        _C.OPTIM.LOOKAHEAD = CN()
        _C.OPTIM.LOOKAHEAD.USE = True
        _C.OPTIM.LOOKAHEAD.ALPHA = 0.5
        _C.OPTIM.LOOKAHEAD.STEPS = 5

        # We set different learning rates for the CNN (visual backbone) and the
        # rest of the model. The CNN LR is typically much higher when training
        # from scratch. Both LRs follow the same warmup-decay schedule.

        # Total batch size (will be distributed evenly across GPUs).
        _C.OPTIM.BATCH_SIZE = 256
        # Max learning rate for CNN (visual backbone).
        _C.OPTIM.CNN_LR = 0.2
        # Max learning rate for rest of the model.
        _C.OPTIM.LR = 0.001
        # Number of iterations to train for; batches are randomly sampled.
        _C.OPTIM.NUM_ITERATIONS = 500000

        # Number of steps at the start of training for linear LR warmup.
        _C.OPTIM.WARMUP_STEPS = 10000
        # Learning rate annealing schedule for decay after warmup.
        # Possible choices: {"none", "linear", "cosine", "multistep"}.
        _C.OPTIM.LR_DECAY_NAME = "cosine"
        # Steps to decay LR for "multistep" schedule.
        _C.OPTIM.LR_STEPS = []
        # Factor to multiply with LR for "multistep" schedule.
        _C.OPTIM.LR_GAMMA = 0.1

        # Override parameter values from the YAML file first, then from the
        # override list.
        self._C = _C
        if config_file is not None:
            self._C.merge_from_file(config_file)
        self._C.merge_from_list(override_list)

        # Make an instantiated object of this class immutable.
        self._C.freeze()

    def dump(self, file_path: str):
        r"""Save config at the specified file path.

        Args:
            file_path: Path to save config file (YAML).
        """
        # Use a context manager so the file handle is flushed and closed.
        with open(file_path, "w") as f:
            self._C.dump(stream=f)

    def __getattr__(self, attr: str):
        return self._C.__getattr__(attr)

    def __str__(self):
        return self._C.__str__()

    def __repr__(self):
        return self._C.__repr__()
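
The `OPTIM.WARMUP_STEPS`, `OPTIM.NUM_ITERATIONS` and `OPTIM.LR_DECAY_NAME = "cosine"` settings in `virtex/config.py` describe a linear warmup followed by cosine decay, applied to both `OPTIM.CNN_LR` and `OPTIM.LR`. The sketch below shows one common form of that schedule in plain Python. `lr_at_step` is a hypothetical helper, not the virtex scheduler, and it assumes the LR decays to exactly zero at the last iteration, which may differ in detail from what virtex builds.

```python
import math


def lr_at_step(step: int, max_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate after ``step`` iterations: linear warmup, then cosine decay.

    Hypothetical helper (not the virtex scheduler). Assumes the LR reaches 0
    at ``total_steps``.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp from 0 up to max_lr over the warmup period.
        return max_lr * step / warmup_steps
    # Fraction of the post-warmup phase completed, clipped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    # Cosine annealing from max_lr down to 0.
    return max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With the config defaults (`CNN_LR = 0.2`, 10,000 warmup steps, 500,000 iterations), the CNN LR reaches 0.2 at step 10,000 and is back at 0.1 half-way through the remaining 490,000 steps (step 255,000).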

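`OPTIM.NO_DECAY` is a regular expression selecting parameters that are excluded from weight decay. The following sketch shows which names it selects, assuming it is applied with `re.match` to parameter names such as those returned by `model.named_parameters()`. The helper `split_by_decay` and the parameter names are hypothetical; the names only imitate a `visual`/`textual` module layout.

```python
import re
from typing import Iterable, List, Tuple

# Value of _C.OPTIM.NO_DECAY in virtex/config.py.
NO_DECAY = ".*textual.(embedding|transformer).*(norm.*|bias)"


def split_by_decay(
    names: Iterable[str], pattern: str = NO_DECAY
) -> Tuple[List[str], List[str]]:
    """Split parameter names into (decay, no_decay) lists.

    Assumes ``re.match`` semantics: the pattern must match from the start of
    the name, which the leading ``.*`` makes permissive.
    """
    decay: List[str] = []
    no_decay: List[str] = []
    for name in names:
        if re.match(pattern, name):
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay


# Hypothetical parameter names, imitating a visual/textual model layout.
NAMES = [
    "visual.cnn.conv1.weight",
    "visual.cnn.bn1.bias",
    "textual.embedding.words.weight",
    "textual.embedding.layer_norm.weight",
    "textual.transformer.layers.0.linear1.weight",
    "textual.transformer.layers.0.linear1.bias",
    "textual.output.bias",
]
```

The pattern only exempts norm parameters and biases inside `textual.embedding` and `textual.transformer`; biases elsewhere, such as in the visual backbone, still receive weight decay.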

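`DATA.MASKED_LM` and `MaskedLmDataset` choose `ceil((n - 2) * MASK_PROPORTION)` non-boundary tokens from a caption of `n` tokens, then replace each with [MASK] or a random token. The dataset code is truncated in this listing, so the sketch below assumes a BERT-style completion: with probability `MASK_PROBABILITY` a chosen token becomes [MASK], with probability `REPLACE_PROBABILITY` it becomes a random vocabulary ID, and otherwise it stays unchanged. `mask_tokens` is a hypothetical helper; like the visible dataset code, it always uses [MASK] when only one token is chosen.

```python
import math
import random
from typing import List, Optional, Tuple


def mask_tokens(
    tokens: List[int],
    mask_id: int,
    vocab_size: int,
    padding_idx: int = 0,
    mask_proportion: float = 0.15,
    mask_prob: float = 0.85,
    replace_prob: float = 0.10,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Return (masked_tokens, masked_labels) for one caption.

    ``tokens`` is assumed to start with [SOS] and end with [EOS]. Labels are
    ``padding_idx`` everywhere except at the chosen positions.
    """
    rng = rng or random.Random()
    tokens = list(tokens)
    labels = [padding_idx] * len(tokens)

    # Leave out the first and last positions (boundary tokens).
    candidates = list(range(1, len(tokens) - 1))
    chosen = rng.sample(candidates, math.ceil(len(candidates) * mask_proportion))

    for i in chosen:
        labels[i] = tokens[i]
        r = rng.random()
        if len(chosen) == 1 or r < mask_prob:
            tokens[i] = mask_id
        elif r < mask_prob + replace_prob:
            tokens[i] = rng.randrange(vocab_size)
        # Otherwise keep the token (probability 1 - mask_prob - replace_prob).
    return tokens, labels
```

With the config defaults, about 5% of chosen tokens (1 - 0.85 - 0.10) are left as-is but still have to be predicted, since their labels are set.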
================================================
FILE: virtex/data/__init__.py
================================================
from .datasets.captioning import CaptioningDataset
from .datasets.coco_captions import CocoCaptionsDataset
from .datasets.classification import (
    TokenClassificationDataset,
    MultiLabelClassificationDataset,
)
from .datasets.masked_lm import MaskedLmDataset
from .datasets.downstream import (
    ImageNetDataset,
    INaturalist2018Dataset,
    VOC07ClassificationDataset,
    ImageDirectoryDataset,
)

__all__ = [
    "CocoCaptionsDataset",
    "CaptioningDataset",
    "TokenClassificationDataset",
    "MultiLabelClassificationDataset",
    "MaskedLmDataset",
    "ImageDirectoryDataset",
    "ImageNetDataset",
    "INaturalist2018Dataset",
    "VOC07ClassificationDataset",
]


================================================
FILE: virtex/data/datasets/captioning.py
================================================
import random
from typing import Callable, Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T
from .coco_captions import CocoCaptionsDataset


class CaptioningDataset(Dataset):
    r"""
    A dataset which provides image-caption (forward and backward) pairs from
    a COCO Captions annotation file. This is used for pretraining tasks which
    use captions - bicaptioning, forward captioning and token classification.

    Args:
        data_root: Path to dataset directory containing images and annotations.
        split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
        tokenizer: Tokenizer which maps word tokens to their integer IDs.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
        max_caption_length: Maximum number of tokens to keep in caption tokens.
            Extra tokens will be trimmed from the right end of the token list.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
    ):
        self._dset = CocoCaptionsDataset(data_root, split)
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_caption_length = max_caption_length

        # Short handles for common tokens for convenience:
        self.padding_idx = tokenizer.token_to_id("<unk>")
        self.sos_id = tokenizer.token_to_id("[SOS]")
        self.eos_id = tokenizer.token_to_id("[EOS]")

    def __len__(self):
        return len(self._dset)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:

        # keys: {"image_id", "image", "captions"}
        instance = self._dset[idx]
        image_id, image, captions = (
            instance["image_id"],
            instance["image"],
            instance["captions"],
        )
        caption = random.choice(captions)

        # Transform image-caption pair and convert image from HWC to CHW format.
        # Pass in caption to image_transform due to paired horizontal flip.
        # Caption won't be tokenized/processed here.
        image_caption = self.image_transform(image=image, caption=caption)
        image, caption = image_caption["image"], image_caption["caption"]
        image = np.transpose(image, (2, 0, 1))

        caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id]
        caption_tokens = caption_tokens[: self.max_caption_length]
        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long),
            "noitpac_tokens": torch.tensor(caption_tokens, dtype=torch.long).flip(0),
            "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:

        # Pad `caption_tokens` and `noitpac_tokens` up to the longest sequence
        # in the batch.
        caption_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["caption_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        noitpac_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["noitpac_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "caption_tokens": caption_tokens,
            "noitpac_tokens": noitpac_tokens,
            "caption_lengths": torch.stack([d["caption_lengths"] for d in data]),
        }
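`CaptioningDataset.__getitem__` frames each caption as `[SOS] ... [EOS]`, truncates it to `max_caption_length`, and builds the backward (`noitpac`) sequence by flipping it before padding, so `collate_fn` pads both directions on the right. The pure-Python sketch below mirrors that logic with lists instead of tensors. `make_caption_pair` and `pad_batch` are hypothetical helpers, and the IDs assume the config defaults `SOS_INDEX = 1`, `EOS_INDEX = 2` and padding index 0.

```python
from typing import Dict, List, Union


def make_caption_pair(
    ids: List[int], sos_id: int = 1, eos_id: int = 2, max_caption_length: int = 30
) -> Dict[str, Union[List[int], int]]:
    # Frame with boundary tokens, then truncate from the right.
    tokens = [sos_id, *ids, eos_id][:max_caption_length]
    return {
        "caption_tokens": tokens,
        # Reverse each sample before padding, as the dataset flips per sample.
        "noitpac_tokens": tokens[::-1],
        "caption_lengths": len(tokens),
    }


def pad_batch(seqs: List[List[int]], padding_idx: int = 0) -> List[List[int]]:
    # Right-pad every sequence to the longest one in the batch, as
    # torch.nn.utils.rnn.pad_sequence(..., batch_first=True) does for tensors.
    longest = max(len(s) for s in seqs)
    return [s + [padding_idx] * (longest - len(s)) for s in seqs]
```

Because truncation happens after `[EOS]` is appended, captions with more than `max_caption_length - 2` encoded tokens lose their `[EOS]` token.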


================================================
FILE: virtex/data/datasets/classification.py
================================================
from collections import defaultdict
import glob
import json
import os
import random
from typing import Any, Callable, Dict, List, Tuple

import albumentations as alb
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T
from .coco_captions import CocoCaptionsDataset


class TokenClassificationDataset(Dataset):
    r"""
    A dataset which provides image-labelset pairs from a COCO Captions annotation
    file. The set of caption tokens (unordered) is treated as a labelset.

    Args:
        data_root: Path to dataset directory containing images and annotations.
        split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
        tokenizer: Tokenizer which maps word tokens to their integer IDs.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
        max_caption_length: Maximum number of tokens to keep in caption tokens.
            Extra tokens will be trimmed from the right end of the token list.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
    ):
        self._dset = CocoCaptionsDataset(data_root, split)
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_caption_length = max_caption_length

        # Short handles for common tokens for convenience:
        self.padding_idx = tokenizer.token_to_id("<unk>")
        self.sos_id = tokenizer.token_to_id("[SOS]")
        self.eos_id = tokenizer.token_to_id("[EOS]")

    def __len__(self):
        return len(self._dset)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:

        # keys: {"image_id", "image", "captions"}
        instance = self._dset[idx]
        image_id, image, captions = (
            instance["image_id"],
            instance["image"],
            instance["captions"],
        )
        caption = random.choice(captions)

        # Transform image-caption pair and convert image from HWC to CHW format.
        # Pass in caption to image_transform due to paired horizontal flip.
        # Caption won't be tokenized/processed here.
        image_caption = self.image_transform(image=image, caption=caption)
        image, caption = image_caption["image"], image_caption["caption"]
        image = np.transpose(image, (2, 0, 1))

        caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id]
        caption_tokens = caption_tokens[: self.max_caption_length]
        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "labels": torch.tensor(caption_tokens, dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:

        labels = torch.nn.utils.rnn.pad_sequence(
            [d["labels"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "labels": labels,
        }


class MultiLabelClassificationDataset(Dataset):
    r"""
    A dataset which provides image-labelset pairs from COCO instance annotation
    files. This is used for multilabel classification pretraining task.

    Args:
        data_root: Path to dataset directory containing images and annotations.
        split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.image_transform = image_transform

        # Make a list of (image_id, filename) tuples, parsing image_id from the
        # filename (assuming the directory has images named in COCO 2017 format).
        image_filenames = glob.glob(os.path.join(data_root, f"{split}2017", "*.jpg"))
        self.id_filename: List[Tuple[int, str]] = [
            (int(os.path.basename(name)[:-4]), name) for name in image_filenames
        ]
        # Load the instance (bounding box and mask) annotations.
        _annotations = json.load(
            open(os.path.join(data_root, "annotations", f"instances_{split}2017.json"))
        )
        # Map each COCO category ID to a consecutive index: COCO has 80 classes
        # with non-contiguous IDs in the range 1-90. Indices start from 1, as
        # 0 is reserved for background (padding idx).
        _category_ids = {
            ann["id"]: index + 1 for index, ann in enumerate(_annotations["categories"])
        }
        # Mapping from image ID to list of unique category IDs (indices as above)
        # in corresponding image.
        self._labels: Dict[int, List[int]] = defaultdict(list)

        for ann in _annotations["annotations"]:
            self._labels[ann["image_id"]].append(_category_ids[ann["category_id"]])

        # De-duplicate labels and drop empty ones: classification only needs
        # the set of categories present in each image.
        self._labels = {
            _id: list(set(lbl)) for _id, lbl in self._labels.items() if len(lbl) > 0
        }
        # Filter out image IDs which didn't have any labels.
        self.id_filename = [
            (t[0], t[1]) for t in self.id_filename if t[0] in self._labels
        ]
        # Padding index used while forming a batch, because images may have a
        # variable number of instances. We do not need a padding index from a
        # tokenizer: conventionally, COCO category ID 0 is background.
        self.padding_idx = 0

    def __len__(self):
        return len(self.id_filename)

    def __getitem__(self, idx: int):
        # Get image ID and filename.
        image_id, filename = self.id_filename[idx]

        # Open image from path and apply transformation, convert to CHW format.
        image = cv2.imread(filename)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        # Get a list of instances present in the image.
        labels = self._labels[image_id]

        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:

        labels = torch.nn.utils.rnn.pad_sequence(
            [d["labels"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "labels": labels,
        }
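`MultiLabelClassificationDataset` remaps COCO's non-contiguous category IDs to consecutive indices starting at 1, de-duplicates each image's labels, and pads label lists with 0. The sketch below reproduces that mapping on toy annotation dicts. `build_labels` is hypothetical and sorts labels only to make the output deterministic. `to_multi_hot` is an assumption about how a model might consume the padded lists, since the model code is not part of this listing.

```python
from collections import defaultdict
from typing import Dict, List


def build_labels(categories: List[dict], annotations: List[dict]) -> Dict[int, List[int]]:
    # COCO category ID -> consecutive index; 0 stays reserved for padding.
    cat_to_index = {cat["id"]: i + 1 for i, cat in enumerate(categories)}
    per_image: Dict[int, List[int]] = defaultdict(list)
    for ann in annotations:
        per_image[ann["image_id"]].append(cat_to_index[ann["category_id"]])
    # De-duplicate; sorted() only makes the output order deterministic.
    return {img: sorted(set(lbls)) for img, lbls in per_image.items() if lbls}


def to_multi_hot(padded_labels: List[int], num_classes: int, padding_idx: int = 0) -> List[int]:
    # Hypothetical: turn one padded label row into a 0/1 vector over the
    # classes 1..num_classes, skipping padding entries.
    target = [0] * num_classes
    for idx in padded_labels:
        if idx != padding_idx:
            target[idx - 1] = 1
    return target
```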


================================================
FILE: virtex/data/datasets/coco_captions.py
================================================
from collections import defaultdict
import json
import os
import unicodedata
from typing import Dict, List

import cv2
from torch.utils.data import Dataset


class CocoCaptionsDataset(Dataset):
    r"""
    A PyTorch dataset to read COCO Captions dataset and provide it completely
    unprocessed. This dataset is used by various task-specific datasets
    in :mod:`~virtex.data.datasets` module.

    Args:
        data_root: Path to the COCO dataset root directory.
        split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
    """

    def __init__(self, data_root: str, split: str):

        # Get paths to image directory and annotation file.
        image_dir = os.path.join(data_root, f"{split}2017")
        captions = json.load(
            open(os.path.join(data_root, "annotations", f"captions_{split}2017.json"))
        )
        # Collect the list of captions for each image.
        captions_per_image: Dict[int, List[str]] = defaultdict(list)

        for ann in captions["annotations"]:
            # Perform common normalization: lowercase, NFKD Unicode
            # normalization, then strip accents (combining characters).
            caption = ann["caption"].lower()
            caption = unicodedata.normalize("NFKD", caption)
            caption = "".join([ch for ch in caption if not unicodedata.combining(ch)])

            captions_per_image[ann["image_id"]].append(caption)

        # Collect image file for each image (by its ID).
        image_filepaths: Dict[int, str] = {
            im["id"]: os.path.join(image_dir, im["file_name"])
            for im in captions["images"]
        }
        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(image_id, file_path, list[captions])``.
        self.instances = [
            (im_id, image_filepaths[im_id], captions_per_image[im_id])
            for im_id in captions_per_image.keys()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_id, image_path, captions = self.instances[idx]

        # shape: (height, width, channels), dtype: uint8
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return {"image_id": image_id, "image": image, "captions": captions}
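The normalization in `CocoCaptionsDataset.__init__` lowercases each caption, applies NFKD normalization, and drops combining characters, which strips accents. Extracted into a standalone function for illustration (the name `normalize_caption` is hypothetical), it behaves as follows. Note that it does not trim whitespace, and NFKD also expands compatibility characters such as the `ﬁ` ligature.

```python
import unicodedata


def normalize_caption(caption: str) -> str:
    # Lowercase, decompose (NFKD), then drop combining marks such as accents.
    caption = caption.lower()
    caption = unicodedata.normalize("NFKD", caption)
    return "".join(ch for ch in caption if not unicodedata.combining(ch))
```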


================================================
FILE: virtex/data/datasets/downstream.py
================================================
from collections import defaultdict
import glob
import json
import os
from typing import Callable, Dict, List, Tuple

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageNet

from virtex.data import transforms as T


class ImageNetDataset(ImageNet):
    r"""
    Simple wrapper over torchvision's ImageNet dataset. The image transform is
    applied here instead of being passed to the super class.

    Args:
        data_root: Path to the ImageNet dataset directory.
        split: Which split to read from. One of ``{"train", "val"}``.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
    """

    def __init__(
        self,
        data_root: str = "datasets/imagenet",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        super().__init__(data_root, split)
        self.image_transform = image_transform

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        image, label = super().__getitem__(idx)

        # Apply transformation to image and convert to CHW format.
        image = self.image_transform(image=np.array(image))["image"]
        image = np.transpose(image, (2, 0, 1))
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }


class INaturalist2018Dataset(Dataset):
    r"""
    A dataset which provides image-label pairs from the iNaturalist 2018 dataset.

    Args:
        data_root: Path to the iNaturalist 2018 dataset directory.
        split: Which split to read from. One of ``{"train", "val"}``.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
    """

    def __init__(
        self,
        data_root: str = "datasets/inaturalist",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.split = split
        self.image_transform = image_transform

        annotations = json.load(
            open(os.path.join(data_root, "annotations", f"{split}2018.json"))
        )
        # Make a mapping from image IDs to file paths.
        self.image_id_to_file_path = {
            ann["id"]: os.path.join(data_root, ann["file_name"])
            for ann in annotations["images"]
        }
        # Form a list of instances as (image_id, category_id) tuples.
        self.instances = [
            (ann["image_id"], ann["category_id"])
            for ann in annotations["annotations"]
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_id, label = self.instances[idx]
        image_path = self.image_id_to_file_path[image_id]

        # Open image from path and apply transformation, convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }


class VOC07ClassificationDataset(Dataset):
    r"""
    A dataset which provides image-label pairs from the PASCAL VOC 2007 dataset.

    Args:
        data_root: Path to VOC 2007 directory containing sub-directories named
            ``Annotations``, ``ImageSets``, and ``JPEGImages``.
        split: Which split to read from. One of ``{"trainval", "test"}``.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
    """

    def __init__(
        self,
        data_root: str = "datasets/VOC2007",
        split: str = "trainval",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.split = split
        self.image_transform = image_transform

        ann_paths = sorted(
            glob.glob(os.path.join(data_root, "ImageSets", "Main", f"*_{split}.txt"))
        )
        # A list like: ["aeroplane", "bicycle", "bird", ...]
        self.class_names = [
            os.path.basename(path).split("_")[0] for path in ann_paths
        ]

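        # Map each image name to a label tensor of shape (num_classes, ) with
        # values in {-1, 0, 1}. VOC annotation files use 1: present,
        # -1: not present, 0: ignore; these are remapped below to training
        # targets 1: present, 0: not present, -1: ignore. Classes not listed
        # for an image default to -1 (ignore).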
        image_names_to_labels: Dict[str, torch.Tensor] = defaultdict(
            lambda: -torch.ones(len(self.class_names), dtype=torch.int32)
        )
        for cls_num, ann_path in enumerate(ann_paths):
            with open(ann_path, "r") as fopen:
                for line in fopen:
                    img_name, orig_label_str = line.strip().split()
                    orig_label = int(orig_label_str)

                    # In VOC data, -1 (not present): set to 0 as train target
                    # In VOC data, 0 (ignore): set to -1 as train target.
                    orig_label = (
                        0 if orig_label == -1 else -1 if orig_label == 0 else 1
                    )
                    image_names_to_labels[img_name][cls_num] = orig_label

        # Convert the dict to a list of tuples for easy indexing.
        # Replace image name with full image path.
        self.instances: List[Tuple[str, List[int]]] = [
            (
                os.path.join(data_root, "JPEGImages", f"{image_name}.jpg"),
                label.tolist(),
            )
            for image_name, label in image_names_to_labels.items()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_path, label = self.instances[idx]

        # Open image from path and apply transformation, convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }


class ImageDirectoryDataset(Dataset):
    r"""
    A dataset which reads images from any directory. This class is useful to
    run image captioning inference on our models with any arbitrary images.

    Args:
        data_root: Path to a directory containing images.
        image_transform: List of image transformations, from either
            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_
            or :mod:`virtex.data.transforms`.
    """

    def __init__(
        self, data_root: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM
    ):
        self.image_paths = glob.glob(os.path.join(data_root, "*"))
        self.image_transform = image_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        image_path = self.image_paths[idx]
        # Remove extension from image name to use as image_id.
        image_id = os.path.splitext(os.path.basename(image_path))[0]

        # Open image from path and apply transformation, convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        # Return image id as a string so the default collate_fn keeps it as a
        # list of strings instead of casting it to a tensor.
        return {"image_id": str(image_id), "image": torch.tensor(image, dtype=torch.float)}
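`VOC07ClassificationDataset` converts the per-class VOC annotation files into one target vector per image, remapping VOC values (1: present, -1: not present, 0: ignore) to training targets (1, 0, -1). The sketch below applies the same remapping to in-memory lines instead of files. `build_voc_targets` is hypothetical: it takes a dict from class name to the lines of that class's `*_{split}.txt` file, and orders classes by sorted name as a stand-in for the dataset's sorted file paths.

```python
from typing import Dict, List

# VOC annotation value -> training target (1: present, 0: absent, -1: ignore).
VOC_TO_TARGET = {1: 1, -1: 0, 0: -1}


def build_voc_targets(class_lines: Dict[str, List[str]]) -> Dict[str, List[int]]:
    class_names = sorted(class_lines)
    targets: Dict[str, List[int]] = {}
    for cls_num, name in enumerate(class_names):
        for line in class_lines[name]:
            img_name, value = line.strip().split()
            # Classes never listed for an image keep the default -1 (ignore).
            row = targets.setdefault(img_name, [-1] * len(class_names))
            row[cls_num] = VOC_TO_TARGET[int(value)]
    return targets
```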


================================================
FILE: virtex/data/datasets/masked_lm.py
================================================
import math
import random
from typing import Callable, Dict, List

import albumentations as alb
import numpy as np
import torch
from torch.utils.data import Dataset

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.data import transforms as T
from .coco_captions import CocoCaptionsDataset


class MaskedLmDataset(Dataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
        mask_proportion: float = 0.15,
        mask_probability: float = 0.80,
        replace_probability: float = 0.10,
    ):
        self._dset = CocoCaptionsDataset(data_root, split)
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_caption_length = max_caption_length

        # Short handles for common tokens for convenience:
        self.padding_idx = tokenizer.token_to_id("<unk>")
        self.sos_id = tokenizer.token_to_id("[SOS]")
        self.eos_id = tokenizer.token_to_id("[EOS]")
        self.mask_id = tokenizer.token_to_id("[MASK]")

        self._vocab_size = tokenizer.get_vocab_size()
        self._mask_proportion = mask_proportion
        self._mask_prob = mask_probability
        self._repl_prob = replace_probability

    def __len__(self):
        return len(self._dset)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:

        # keys: {"image_id", "image", "captions"}
        instance = self._dset[idx]
        image_id, image, captions = (
            instance["image_id"],
            instance["image"],
            instance["captions"],
        )
        caption = random.choice(captions)

        # Transform image-caption pair and convert image from HWC to CHW format.
        # Pass in caption to image_transform due to paired horizontal flip.
        # Caption won't be tokenized/processed here.
        image_caption = self.image_transform(image=image, caption=caption)
        image, caption = image_caption["image"], image_caption["caption"]
        image = np.transpose(image, (2, 0, 1))

        caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id]
        caption_tokens = caption_tokens[: self.max_caption_length]
        # ---------------------------------------------------------------------
        #  Mask some tokens randomly.
        # ---------------------------------------------------------------------
        masked_labels = [self.padding_idx] * len(caption_tokens)

        # Indices in `caption_tokens` to mask: at least one, unless the caption
        # has no tokens besides the boundary tokens ([SOS] and [EOS]), which
        # are never masked.
        tokens_to_mask: List[int] = random.sample(
            list(range(1, len(caption_tokens) - 1)),
            math.ceil((len(caption_tokens) - 2) * self._mask_proportion),
        )
        for i in tokens_to_mask:
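            # Replace with [MASK] (prob `mask_probability`), with a random token
            # (prob `replace_probability`), or leave unchanged (remaining prob).
            # Only [MASK] positions get a label. If only one token is selected,
            # always replace it with [MASK].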
            if len(tokens_to_mask) == 1:
                masked_labels[i] = caption_tokens[i]
                caption_tokens[i] = self.mask_id
            else:
                _flag: float = random.random()
                if _flag <= self._mask_prob + self._repl_prob:
                    if _flag <= self._mask_prob:
                        masked_labels[i] = caption_tokens[i]
                        caption_tokens[i] = self.mask_id
                    else:
                        caption_tokens[i] = self._random_token_index()
        # ---------------------------------------------------------------------

        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long),
            "masked_labels": torch.tensor(masked_labels, dtype=torch.long),
            "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:

        # Pad `caption_tokens` and `masked_labels` to the longest caption in batch.
        caption_tokens = torch.nn.utils.rnn.pad_sequence(
            [d["caption_tokens"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        masked_labels = torch.nn.utils.rnn.pad_sequence(
            [d["masked_labels"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "caption_tokens": caption_tokens,
            "masked_labels": masked_labels,
            "caption_lengths": torch.stack([d["caption_lengths"] for d in data]),
        }

    def _random_token_index(self) -> int:
        return random.randint(0, self._vocab_size - 1)
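The masking logic in `MaskedLmDataset.__getitem__` follows a BERT-style recipe. Below is a minimal, framework-free sketch of the same rules, assuming plain integer token lists. The function name `mask_tokens` and its parameters are hypothetical, and a seeded `random.Random` stands in for the global `random` module so results are reproducible.

```python
import math
import random
from typing import List, Tuple


def mask_tokens(
    tokens: List[int],
    mask_id: int,
    pad_id: int,
    vocab_size: int,
    rng: random.Random,
    mask_proportion: float = 0.15,
    mask_prob: float = 0.80,
    replace_prob: float = 0.10,
) -> Tuple[List[int], List[int]]:
    """Return (masked tokens, labels); labels are `pad_id` where no loss applies."""
    tokens = list(tokens)  # work on a copy, leave the caller's list untouched
    labels = [pad_id] * len(tokens)

    # Candidate positions exclude the first and last ([SOS] / [EOS]) tokens.
    candidates = list(range(1, len(tokens) - 1))
    chosen = rng.sample(candidates, math.ceil(len(candidates) * mask_proportion))

    for i in chosen:
        # A lone selected token is always masked; otherwise draw once per token.
        flag = 0.0 if len(chosen) == 1 else rng.random()
        if flag <= mask_prob:
            labels[i] = tokens[i]  # the model must predict the original token
            tokens[i] = mask_id
        elif flag <= mask_prob + replace_prob:
            tokens[i] = rng.randint(0, vocab_size - 1)  # random token, no label
        # Otherwise the token stays unchanged and gets no label.
    return tokens, labels


rng = random.Random(0)
caption = [1] + list(range(10, 30)) + [2]  # [SOS], 20 word tokens, [EOS]
masked, labels = mask_tokens(caption, mask_id=3, pad_id=0, vocab_size=100, rng=rng)
print(masked)
print(labels)
```

As in the dataset code, randomly replaced tokens receive no label here, so the loss only covers `[MASK]` positions; the original BERT recipe instead computes the loss at every selected position, including replaced and unchanged ones.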


================================================
FILE: virtex/data/tokenizers.py
================================================
from typing import Any, Dict, List

import sentencepiece as sp


class SentencePieceBPETokenizer:
    r"""
    A tokenizer based on `SentencePiece <https://github.com/google/sentencepiece>`_
    with BPE sub-routine. It encodes caption strings into lists of integer token IDs.

    Args:
        model_path: Path to the ``.model`` file trained by SentencePiece.
    """
    SP_SPACE = u"▁"

    def __init__(self, model_path: str):
        self.model_path = model_path

        # Load pretrained tokenizer model.
        self.model = sp.SentencePieceProcessor()
        self.model.Load(model_path)

    def __getstate__(self):
        r"""
        This magic method, along with ``__setstate__``, makes an object of this
        class picklable (and usable while data loading with multiple workers).
        The SentencePiece model itself is not pickled; ``__setstate__`` reloads
        it from ``model_path``.
        """
        state_dict = self.__dict__.copy()
        state_dict["model"] = None
        return state_dict

    def __setstate__(self, state_dict: Dict[str, Any]):
        self.__dict__ = state_dict

        self.model = sp.SentencePieceProcessor()
        self.model.Load(self.model_path)

    def get_vocab_size(self) -> int:
        r"""Return number of tokens in vocabulary (including special tokens)."""
        return len(self.model)

    def token_to_id(self, token: str) -> int:
        r"""Get integer ID of a string token (ID of ``<unk>`` if it does not exist)."""
        return self.model.piece_to_id(token)

    def id_to_token(self, token_id: int) -> str:
        r"""Get string token of an integer ID (``<unk>`` if does not exist)."""
        return self.model.id_to_piece(token_id)

    def encode(self, text: str) -> List[int]:
        r"""Convert a text string to a list of integer token ids."""
        return self.model.EncodeAsIds(text)

    def decode(self, token_ids: List[int]) -> str:
        r"""Convert a sequence of token IDs to a text string."""
        return self.model.DecodeIds(token_ids)


================================================
FILE: virtex/data/transforms.py
================================================
import re

import albumentations as alb
import cv2


class HorizontalFlip(alb.BasicTransform):
    r"""
    With probability ``p``, flip the image horizontally and swap the words
    "left" and "right" in the caption.

    .. note::

        This transform can also work on images only (without the captions).
        In that case it behaves the same as albumentations
        :class:`~albumentations.augmentations.transforms.HorizontalFlip`.

    Examples:
        >>> flip = HorizontalFlip(p=0.5)
        >>> out1 = flip(image=image, caption=caption)  # keys: {"image", "caption"}
        >>> # Also works with images (without caption).
        >>> out2 = flip(image=image)  # keys: {"image"}

    """

    @property
    def targets(self):
        return {"image": self.apply, "caption": self.apply_to_caption}

    def apply(self, img, **params):
        return cv2.flip(img, 1)

    def apply_to_caption(self, caption, **params):
        # Match only at the start of a word, so "bright" or "upright" stay
        # intact while "leftmost" still becomes "rightmost".
        return re.sub(
            r"\b(left|right)",
            lambda match: "right" if match.group(1) == "left" else "left",
            caption,
        )


class RandomResizedSquareCrop(alb.RandomResizedCrop):
    r"""
    A variant of :class:`albumentations.augmentations.transforms.RandomResizedCrop`
    which assumes a square crop (width = height). Everything else is the same.

    Args:
        size: Dimension of the width and height of the cropped image.
    """

    def __init__(self, size: int, *args, **kwargs):
        super().__init__(height=size, width=size, *args, **kwargs)


class CenterSquareCrop(alb.CenterCrop):
    r"""
    A variant of :class:`albumentations.augmentations.transforms.CenterCrop`
    which assumes a square crop (width = height). Everything else is the same.

    Args:
        size: Dimension of the width and height of the cropped image.
    """

    def __init__(self, size: int, *args, **kwargs):
        super().__init__(height=size, width=size, *args, **kwargs)


class SquareResize(alb.Resize):
    r"""
    A variant of :class:`albumentations.augmentations.transforms.Resize` which
    assumes a square resize (width = height). Everything else is the same.

    Args:
        size: Dimension of the width and height of the resized image.
    """

    def __init__(self, size: int, *args, **kwargs):
        super().__init__(height=size, width=size, *args, **kwargs)


# =============================================================================
#   SOME COMMON CONSTANTS AND IMAGE TRANSFORMS:
#   These serve as references here, and are used as default params in many
#   dataset class constructors.
# -----------------------------------------------------------------------------

IMAGENET_COLOR_MEAN = (0.485, 0.456, 0.406)
r"""ImageNet color normalization mean in RGB format (values in 0-1)."""

IMAGENET_COLOR_STD = (0.229, 0.224, 0.225)
r"""ImageNet color normalization std in RGB format (values in 0-1)."""

DEFAULT_IMAGE_TRANSFORM = alb.Compose(
    [
        alb.SmallestMaxSize(256, p=1.0),
        CenterSquareCrop(224, p=1.0),
        alb.Normalize(mean=IMAGENET_COLOR_MEAN, std=IMAGENET_COLOR_STD, p=1.0),
    ]
)
r"""
Default transform without any data augmentation: resize the shorter side to
256 pixels, center-crop to 224 x 224, and normalize with ImageNet statistics.
"""
# =============================================================================
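`DEFAULT_IMAGE_TRANSFORM` resizes the shorter side to 256, center-crops 224 x 224, and normalizes with the ImageNet statistics. The arithmetic can be sketched without any image library, as below. This is an approximation under stated assumptions: resized sides are rounded to the nearest integer, the crop offset is floor-halved, and pixels are divided by 255 before standardization (matching the `max_pixel_value=255.0` default of albumentations' `Normalize`). The library's own rounding may differ by a pixel, and the helper names are hypothetical.

```python
from typing import Tuple

IMAGENET_COLOR_MEAN = (0.485, 0.456, 0.406)
IMAGENET_COLOR_STD = (0.229, 0.224, 0.225)


def smallest_side_resize(height: int, width: int, max_size: int = 256) -> Tuple[int, int]:
    # Scale both sides so the shorter one becomes `max_size` (aspect ratio kept).
    scale = max_size / min(height, width)
    return round(height * scale), round(width * scale)


def center_crop_box(height: int, width: int, size: int = 224) -> Tuple[int, int, int, int]:
    # (top, left, bottom, right) of a centered `size` x `size` window.
    top, left = (height - size) // 2, (width - size) // 2
    return top, left, top + size, left + size


def normalize_pixel(
    rgb: Tuple[int, int, int], max_pixel_value: float = 255.0
) -> Tuple[float, ...]:
    # Map [0, 255] to [0, 1], then standardize each channel.
    return tuple(
        (v / max_pixel_value - m) / s
        for v, m, s in zip(rgb, IMAGENET_COLOR_MEAN, IMAGENET_COLOR_STD)
    )


h, w = smallest_side_resize(480, 640)  # a 480 x 640 (H x W) photo
print((h, w), center_crop_box(h, w))  # (256, 341) (16, 58, 240, 282)
print(normalize_pixel((255, 255, 255)))
```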


================================================
FILE: virtex/factories.py
================================================
r"""
This module is a collection of *factories* for creating objects of datasets,
models, optimizers and other useful components. For example, a ResNet-50
visual backbone can be created as:

    .. code-block:: python

        >>> # Explicitly by name, args and kwargs:
        >>> backbone = VisualBackboneFactory.create(
        ...     "torchvision::resnet50", pretrained=False
        ... )
        >>> # Directly from a config object:
        >>> _C = Config(override_list=["MODEL.VISUAL.NAME", "torchvision::resnet50"])
        >>> backbone = VisualBackboneFactory.from_config(_C)

Creating directly from :class:`~virtex.config.Config` is fast and simple, and
ensures minimal changes throughout the codebase upon any change in the call
signature of an underlying class or in the config hierarchy. Refer to the
descriptions of specific factories for more details.
"""
import re
from functools import partial
from typing import Any, Callable, Dict, Iterable, List

import albumentations as alb
from torch import nn, optim

import virtex.data as vdata
import virtex.models as vmodels
from virtex.config import Config
from virtex.data import transforms as T
from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules import visual_backbones, textual_heads
from virtex.optim import Lookahead, lr_scheduler

from virtex.utils.beam_search import AutoRegressiveBeamSearch
from virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling


class Factory:
    r"""
    Base class for all factories. All factories must inherit this base class
    and follow these guidelines for a consistent behavior:

    * Factory objects cannot be instantiated, doing ``factory = SomeFactory()``
      is illegal. Child classes should not implement ``__init__`` methods.
    * All factories must have an attribute named ``PRODUCTS`` of type
      ``Dict[str, Callable]``, which associates each class with a unique string
      name which can be used to create it.
    * All factories must implement one classmethod, :meth:`from_config`, which
      contains logic for creating an object by taking its name and other
      arguments directly from :class:`~virtex.config.Config`. It can use
      :meth:`create`, already implemented in this base class.
    * :meth:`from_config` should not take extra arguments beyond the config
      itself, unless necessary (such as model parameters for optimizer).
    """

    PRODUCTS: Dict[str, Callable] = {}

    def __init__(self):
        raise ValueError(
            f"Cannot instantiate {self.__class__.__name__} object, use the "
            "`create` or `from_config` classmethod to create a product."
        )

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        r"""Create an object by its name, args and kwargs."""
        if name not in cls.PRODUCTS:
            raise KeyError(f"{cls.__name__} cannot create {name}.")

        return cls.PRODUCTS[name](*args, **kwargs)

    @classmethod
    def from_config(cls, config: Config) -> Any:
        r"""Create an object directly from config."""
        raise NotImplementedError


class TokenizerFactory(Factory):
    r"""
    Factory to create text tokenizers. This codebase only supports one tokenizer
    for now, but having a dedicated factory makes it easy to add more if needed.

    Possible choices: ``{"SentencePieceBPETokenizer"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "SentencePieceBPETokenizer": SentencePieceBPETokenizer
    }

    @classmethod
    def from_config(cls, config: Config) -> SentencePieceBPETokenizer:
        r"""
        Create a tokenizer directly from config.

        Args:
            config: Config object with all the parameters.
        """

        _C = config

        tokenizer = cls.create(
            "SentencePieceBPETokenizer",
            model_path=_C.DATA.TOKENIZER_MODEL,
        )
        return tokenizer


class ImageTransformsFactory(Factory):
    r"""
    Factory to create image transformations for common preprocessing and data
    augmentations. These are a mix of default transformations from
    `albumentations <https://albumentations.readthedocs.io/en/latest/>`_ and
    some extended ones defined in :mod:`virtex.data.transforms`.

    It uses sensible default values, which can be overridden by appending a
    dict of kwargs to the name. Example: ``random_resized_crop::{'scale': (0.08, 1.0)}``

    .. note::

        This factory does not implement :meth:`from_config` method. It is only
        used by :class:`PretrainingDatasetFactory` and
        :class:`DownstreamDatasetFactory`.

    Possible choices: ``{"center_crop", "horizontal_flip", "random_resized_crop",
    "normalize", "global_resize", "color_jitter", "smallest_resize"}``.
    """

    # fmt: off
    PRODUCTS: Dict[str, Callable] = {
        # Input resize transforms: whenever selected, these are always applied.
        # These transforms require one positional argument: image dimension.
        "random_resized_crop": partial(
            T.RandomResizedSquareCrop, scale=(0.2, 1.0), ratio=(0.75, 1.333), p=1.0
        ),
        "center_crop": partial(T.CenterSquareCrop, p=1.0),
        "smallest_resize": partial(alb.SmallestMaxSize, p=1.0),
        "global_resize": partial(T.SquareResize, p=1.0),

        # Keep hue limits small in color jitter because it changes color drastically
        # and captions often mention colors. Apply with higher probability.
        "color_jitter": partial(
            alb.ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8
        ),
        "horizontal_flip": partial(T.HorizontalFlip, p=0.5),

        # Color normalization: whenever selected, always applied. This accepts images
        # in [0, 255], requires mean and std in [0, 1] and normalizes to `N(0, 1)`.
        "normalize": partial(
            alb.Normalize, mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, p=1.0
        ),
    }
    # fmt: on

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        r"""Create an object by its name, args and kwargs."""

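        if "::" in name:
            # Extra kwargs may follow the name as a dict literal, for example
            # "random_resized_crop::{'scale': (0.08, 1.0)}". The suffix goes
            # through `eval`, so only use trusted config strings here.
            name, kwargs_str = name.split("::", 1)
            _kwargs = eval(kwargs_str)
        else:
            _kwargs = {}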

        _kwargs.update(kwargs)
        return super().create(name, *args, **_kwargs)

    @classmethod
    def from_config(cls, config: Config):
        r"""Augmentations cannot be created from config, only :meth:`create`."""
        raise NotImplementedError


class PretrainingDatasetFactory(Factory):
    r"""
    Factory to create :class:`~torch.utils.data.Dataset` s for pretraining
    VirTex models. Datasets are created depending on pretraining task used.
    Typically these datasets either provide image-caption pairs, or only images
    from COCO Captions dataset (serialized to an LMDB file).

    As an exception, the dataset for ``multilabel_classification`` provides
    COCO images and labels of their bounding box annotations.

    Possible choices: ``{"virtex", "bicaptioning", "captioning", "masked_lm",
    "token_classification", "multilabel_classification"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "virtex": vdata.CaptioningDataset,
        "bicaptioning": vdata.CaptioningDataset,
        "captioning": vdata.CaptioningDataset,
        "masked_lm": vdata.MaskedLmDataset,
        "token_classification": vdata.TokenClassificationDataset,
        "multilabel_classification": vdata.MultiLabelClassificationDataset,
    }

    @classmethod
    def from_config(cls, config: Config, split: str = "train"):
        r"""
        Create a dataset directly from config. Names in this factory match with
        names in :class:`PretrainingModelFactory` because both use same config
        parameter ``MODEL.NAME`` to create objects.

        Args:
            config: Config object with all the parameters.
            split: Which dataset split to load. One of ``{"train", "val"}``.
        """

        _C = config
        # Every dataset needs these two args.
        kwargs = {"data_root": _C.DATA.ROOT, "split": split}

        # Create a list of image transformations based on transform names.
        image_transform_list: List[Callable] = []

        for name in getattr(_C.DATA, f"IMAGE_TRANSFORM_{split.upper()}"):
            # Pass dimensions if cropping / resizing, else rely on the defaults
            # as per `ImageTransformsFactory`.
            if "resize" in name or "crop" in name:
                image_transform_list.append(
                    ImageTransformsFactory.create(name, _C.DATA.IMAGE_CROP_SIZE)
                )
            else:
                image_transform_list.append(ImageTransformsFactory.create(name))

        kwargs["image_transform"] = alb.Compose(image_transform_list)

        # Add dataset specific kwargs.
        if _C.MODEL.NAME != "multilabel_classification":
            tokenizer = TokenizerFactory.from_config(_C)
            kwargs.update(
                tokenizer=tokenizer,
                max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,
            )

        if _C.MODEL.NAME == "masked_lm":
            kwargs.update(
                mask_proportion=_C.DATA.MASKED_LM.MASK_PROPORTION,
                mask_probability=_C.DATA.MASKED_LM.MASK_PROBABILITY,
                replace_probability=_C.DATA.MASKED_LM.REPLACE_PROBABILITY,
            )

        # Dataset names match with model names (and of course pretext names).
        return cls.create(_C.MODEL.NAME, **kwargs)


class DownstreamDatasetFactory(Factory):
    r"""
    Factory to create :class:`~torch.utils.data.Dataset` s for evaluating
    VirTex models on downstream tasks.

    Possible choices: ``{"datasets/VOC2007", "datasets/imagenet",
    "datasets/inaturalist"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "datasets/VOC2007": vdata.VOC07ClassificationDataset,
        "datasets/imagenet": vdata.ImageNetDataset,
        "datasets/inaturalist": vdata.INaturalist2018Dataset,
    }

    @classmethod
    def from_config(cls, config: Config, split: str = "train"):
        r"""
        Create a dataset directly from config. Names in this factory are paths
        of dataset directories (relative to the project directory), because
        config parameter ``DATA.ROOT`` is used to create objects.

        Args:
            config: Config object with all the parameters.
            split: Which dataset split to load. One of ``{"trainval", "test"}``
                for VOC2007, or one of ``{"train", "val"}`` for ImageNet.
        """

        _C = config
        # Every dataset needs these two args.
        kwargs = {"data_root": _C.DATA.ROOT, "split": split}

        # For VOC2007, `IMAGE_TRANSFORM_TRAIN` is used for "trainval" split and
        # `IMAGE_TRANSFORM_VAL` is used for "test" split.
        image_transform_names: List[str] = list(
            _C.DATA.IMAGE_TRANSFORM_TRAIN
            if "train" in split
            else _C.DATA.IMAGE_TRANSFORM_VAL
        )
        # Create a list of image transformations based on names.
        image_transform_list: List[Callable] = []

        for name in image_transform_names:
            # Pass dimensions for resize/crop, else rely on the defaults.
            if name.split("::")[0] in {"random_resized_crop", "center_crop", "global_resize"}:
                transform = ImageTransformsFactory.create(name, 224)
            elif name.split("::")[0] in {"smallest_resize"}:
                transform = ImageTransformsFactory.create(name, 256)
            else:
                transform = ImageTransformsFactory.create(name)

            image_transform_list.append(transform)

        kwargs["image_transform"] = alb.Compose(image_transform_list)

        return cls.create(_C.DATA.ROOT, **kwargs)


class VisualBackboneFactory(Factory):
    r"""
    Factory to create :mod:`~virtex.modules.visual_backbones`. This factory
    supports any ResNet-like model from
    `Torchvision <https://pytorch.org/docs/stable/torchvision/models.html>`_.
    Use the method name for model as in torchvision, for example,
    ``torchvision::resnet50``, ``torchvision::wide_resnet50_2`` etc.

    Possible choices: ``{"torchvision"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "torchvision": visual_backbones.TorchvisionVisualBackbone,
    }

    @classmethod
    def from_config(cls, config: Config) -> visual_backbones.VisualBackbone:
        r"""
        Create a visual backbone directly from config.

        Args:
            config: Config object with all the parameters.
        """

        _C = config
        kwargs = {"visual_feature_size": _C.MODEL.VISUAL.FEATURE_SIZE}

        if "torchvision" in _C.MODEL.VISUAL.NAME:
            # Check the name for models from torchvision.
            cnn_name = _C.MODEL.VISUAL.NAME.split("::")[-1]
            kwargs["pretrained"] = _C.MODEL.VISUAL.PRETRAINED
            kwargs["frozen"] = _C.MODEL.VISUAL.FROZEN

            return cls.create("torchvision", cnn_name, **kwargs)
        else:
            return cls.create(_C.MODEL.VISUAL.NAME, **kwargs)


class TextualHeadFactory(Factory):
    r"""
    Factory to create :mod:`~virtex.modules.textual_heads`. Architectural
    hyperparameters for transformers can be specified as ``name::*``.
    For example, ``transdec_postnorm::L1_H1024_A16_F4096`` would create a
    transformer textual head with ``L = 1`` layers, ``H = 1024`` hidden size,
    ``A = 16`` attention heads, and ``F = 4096`` size of feedforward layers.

    Textual head should be ``"none"`` for pretraining tasks which do not
    involve language modeling, such as ``"token_classification"``.

    Possible choices: ``{"transdec_postnorm", "transdec_prenorm", "none"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "transdec_prenorm": partial(
            textual_heads.TransformerDecoderTextualHead, norm_first=True
        ),
        "transdec_postnorm": partial(
            textual_heads.TransformerDecoderTextualHead, norm_first=False
        ),
        "none": textual_heads.LinearTextualHead,
    }

    @classmethod
    def from_config(cls, config: Config) -> nn.Module:
        r"""
        Create a textual head directly from config.

        Args:
            config: Config object with all the parameters.
        """

        _C = config
        name = _C.MODEL.TEXTUAL.NAME
        kwargs = {
            "visual_feature_size": _C.MODEL.VISUAL.FEATURE_SIZE,
            "vocab_size": _C.DATA.VOCAB_SIZE,
        }

        if "trans" in _C.MODEL.TEXTUAL.NAME:
            # Get architectural hyper-params as per name by matching regex.
            name, architecture = name.split("::")
            match = re.match(r"L(\d+)_H(\d+)_A(\d+)_F(\d+)", architecture)
            if match is None:
                raise ValueError(
                    f"Cannot parse textual head architecture: {architecture}. "
                    "Expected a string like L1_H1024_A16_F4096."
                )

            num_layers = int(match.group(1))
            hidden_size = int(match.group(2))
            attention_heads = int(match.group(3))
            feedforward_size = int(match.group(4))

            # Mask the future tokens for autoregressive captioning.
            mask_future = _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"}

            kwargs.update(
                hidden_size=hidden_size,
                num_layers=num_layers,
                attention_heads=attention_heads,
                feedforward_size=feedforward_size,
                dropout=_C.MODEL.TEXTUAL.DROPOUT,
                mask_future_positions=mask_future,
                max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,
                padding_idx=_C.DATA.UNK_INDEX,
            )
        return cls.create(name, **kwargs)


class PretrainingModelFactory(Factory):
    r"""
    Factory to create :mod:`~virtex.models` for different pretraining tasks.

    Possible choices: ``{"virtex", "bicaptioning", "captioning", "masked_lm",
    "token_classification", "multilabel_classification"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        # First two are basically the same. Added for shorthand notation.
        "virtex": vmodels.VirTexModel,
        "bicaptioning": vmodels.BidirectionalCaptioningModel,
        "captioning": vmodels.ForwardCaptioningModel,
        "masked_lm": vmodels.MaskedLMModel,
        "token_classification": vmodels.TokenClassificationModel,
        "multilabel_classification": vmodels.MultiLabelClassificationModel,
    }

    @classmethod
    def from_config(cls, config: Config) -> nn.Module:
        r"""
        Create a model directly from config.

        Args:
            config: Config object with all the parameters.
        """

        _C = config

        # Build visual and textual streams based on config.
        visual = VisualBackboneFactory.from_config(_C)
        textual = TextualHeadFactory.from_config(_C)

        # Add model specific kwargs. Refer to the call signatures of specific
        # models for matching kwargs here.
        if _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"}:
            kwargs = {
                "sos_index": _C.DATA.SOS_INDEX,
                "eos_index": _C.DATA.EOS_INDEX,
                "decoder": CaptionDecoderFactory.from_config(_C),
            }

        elif _C.MODEL.NAME == "token_classification":
            kwargs = {
                "ignore_indices": [
                    _C.DATA.UNK_INDEX,
                    _C.DATA.SOS_INDEX,
                    _C.DATA.EOS_INDEX,
                    _C.DATA.MASK_INDEX,
                ]
            }
        elif _C.MODEL.NAME == "multilabel_classification":
            kwargs = {"ignore_indices": [0]}  # background index
        else:
            kwargs = {}

        return cls.create(_C.MODEL.NAME, visual, textual, **kwargs)


class CaptionDecoderFactory(Factory):
    r"""
    Factory to create decoders for predicting captions from VirTex models.

    Possible choices: ``{"beam_search", "nucleus_sampling"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "beam_search": AutoRegressiveBeamSearch,
        "nucleus_sampling": AutoRegressiveNucleusSampling,
    }

    @classmethod
    def from_config(cls, config: Config) -> nn.Module:
        r"""
        Create a caption decoder directly from config.

        Args:
            config: Config object with all the parameters.
        """

        _C = config
        kwargs = {
            "eos_index": _C.DATA.EOS_INDEX,
            "max_steps": _C.MODEL.DECODER.MAX_DECODING_STEPS,
        }
        if _C.MODEL.DECODER.NAME == "beam_search":
            kwargs["beam_size"] = _C.MODEL.DECODER.BEAM_SIZE
        elif _C.MODEL.DECODER.NAME == "nucleus_sampling":
            kwargs["nucleus_size"] = _C.MODEL.DECODER.NUCLEUS_SIZE

        return cls.create(_C.MODEL.DECODER.NAME, **kwargs)
        
        
class OptimizerFactory(Factory):
    r"""Factory to create optimizers. Possible choices: ``{"sgd", "adamw"}``."""

    PRODUCTS: Dict[str, Callable] = {"sgd": optim.SGD, "adamw": optim.AdamW}

    @classmethod
    def from_config(
        cls, config: Config, named_parameters: Iterable[Any]
    ) -> optim.Optimizer:
        r"""
        Create an optimizer directly from config.

        Args:
            config: Config object with all the parameters.
            named_parameters: Named parameters of model (retrieved by
                ``model.named_parameters()``) for the optimizer. We use named
                parameters to set different LR and turn off weight decay for
                certain parameters based on their names.
        """

        _C = config

        # Set different learning rate for CNN and rest of the model during
        # pretraining. This doesn't matter for downstream evaluation because
        # there are no modules with "cnn" in their name.
        # Also turn off weight decay for layer norm and bias in textual stream.
        param_groups = []
        for name, param in named_parameters:
            wd = 0.0 if re.match(_C.OPTIM.NO_DECAY, name) else _C.OPTIM.WEIGHT_DECAY
            lr = _C.OPTIM.CNN_LR if "cnn" in name else _C.OPTIM.LR
            param_groups.append({"params": [param], "lr": lr, "weight_decay": wd})

        if _C.OPTIM.OPTIMIZER_NAME == "sgd":
            kwargs = {"momentum": _C.OPTIM.SGD_MOMENTUM}
        else:
            kwargs = {}

        optimizer = cls.create(_C.OPTIM.OPTIMIZER_NAME, param_groups, **kwargs)
        if _C.OPTIM.LOOKAHEAD.USE:
            optimizer = Lookahead(
                optimizer, k=_C.OPTIM.LOOKAHEAD.STEPS, alpha=_C.OPTIM.LOOKAHEAD.ALPHA
            )
        return optimizer


class LRSchedulerFactory(Factory):
    r"""
    Factory to create LR schedulers. All schedulers have a built-in LR warmup
    schedule before actual LR scheduling (decay) starts.

    Possible choices: ``{"none", "multistep", "linear", "cosine"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "none": lr_scheduler.LinearWarmupNoDecayLR,
        "multistep": lr_scheduler.LinearWarmupMultiStepLR,
        "linear": lr_scheduler.LinearWarmupLinearDecayLR,
        "cosine": lr_scheduler.LinearWarmupCosineAnnealingLR,
    }

    @classmethod
    def from_config(
        cls, config: Config, optimizer: optim.Optimizer
    ) -> optim.lr_scheduler.LambdaLR:
        r"""
        Create an LR scheduler directly from config.

        Args:
            config: Config object with all the parameters.
            optimizer: Optimizer on which LR scheduling would be performed.
        """

        _C = config
        kwargs = {
            "total_steps": _C.OPTIM.NUM_ITERATIONS,
            "warmup_steps": _C.OPTIM.WARMUP_STEPS,
        }
        # Multistep LR requires multiplicative factor and milestones.
        if _C.OPTIM.LR_DECAY_NAME == "multistep":
            kwargs.update(gamma=_C.OPTIM.LR_GAMMA, milestones=_C.OPTIM.LR_STEPS)

        return cls.create(_C.OPTIM.LR_DECAY_NAME, optimizer, **kwargs)
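Two factories encode extra arguments in the product name after `::`: `ImageTransformsFactory` reads a dict of kwargs (`random_resized_crop::{'scale': (0.08, 1.0)}`), and `TextualHeadFactory` reads transformer sizes (`transdec_postnorm::L1_H1024_A16_F4096`). The sketch below reimplements both conventions on a toy registry. `ShapeFactory`, its products and `parse_transformer_spec` are hypothetical, and the sketch swaps the original `eval` for `ast.literal_eval`, which only accepts literals such as dicts, tuples and numbers.

```python
import ast
import re
from functools import partial
from typing import Any, Callable, Dict, Tuple


class Factory:
    PRODUCTS: Dict[str, Callable] = {}

    def __init__(self):
        raise TypeError(f"{type(self).__name__} cannot be instantiated; use `create`.")

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        if name not in cls.PRODUCTS:
            # `cls.__name__` is the factory's name; `cls.__class__.__name__` is "type".
            raise KeyError(f"{cls.__name__} cannot create {name}.")
        return cls.PRODUCTS[name](*args, **kwargs)


class ShapeFactory(Factory):
    # Hypothetical products: each simply records the kwargs it was built with.
    PRODUCTS: Dict[str, Callable] = {
        "square": partial(dict, kind="square"),
        "rect": partial(dict, kind="rect"),
    }

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        # "rect::{'w': 3}" -> name "rect" with default kwargs {'w': 3}.
        if "::" in name:
            name, suffix = name.split("::", 1)
            name_kwargs = ast.literal_eval(suffix)
        else:
            name_kwargs = {}
        name_kwargs.update(kwargs)  # explicit kwargs override those in the name
        return super().create(name, *args, **name_kwargs)


def parse_transformer_spec(spec: str) -> Tuple[str, Dict[str, int]]:
    # "transdec_postnorm::L1_H1024_A16_F4096" -> ("transdec_postnorm", {...}).
    name, arch = spec.split("::", 1)
    match = re.fullmatch(r"L(\d+)_H(\d+)_A(\d+)_F(\d+)", arch)
    if match is None:
        raise ValueError(f"Malformed architecture string: {arch!r}")
    layers, hidden, heads, ff = (int(g) for g in match.groups())
    return name, {
        "num_layers": layers,
        "hidden_size": hidden,
        "attention_heads": heads,
        "feedforward_size": ff,
    }


print(ShapeFactory.create("rect::{'w': 3}", h=2))  # {'kind': 'rect', 'w': 3, 'h': 2}
print(parse_transformer_spec("transdec_postnorm::L1_H1024_A16_F4096"))
```

Using `re.fullmatch` instead of `re.match` also rejects strings with trailing characters after the `F` group.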

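`OptimizerFactory` builds one parameter group per tensor, so each can get its own learning rate and weight decay, and every scheduler in `LRSchedulerFactory` starts with a linear warmup. The framework-free sketch below mirrors the param-group rule and shows one plausible warmup-then-cosine LR multiplier. The no-decay pattern, parameter names and numbers are hypothetical, and the multiplier formula is an assumption about the general shape, not a copy of `virtex.optim.lr_scheduler`. In real code, the group dicts could be passed to `torch.optim.SGD` and the multiplier to `torch.optim.lr_scheduler.LambdaLR`.

```python
import math
import re
from typing import Any, Dict, Iterable, List, Tuple


def build_param_groups(
    named_parameters: Iterable[Tuple[str, Any]],
    lr: float,
    cnn_lr: float,
    weight_decay: float,
    no_decay_pattern: str,
) -> List[Dict[str, Any]]:
    groups = []
    for name, param in named_parameters:
        # Names matching the pattern (e.g. norms, biases) skip weight decay.
        wd = 0.0 if re.match(no_decay_pattern, name) else weight_decay
        # Visual CNN parameters get their own learning rate.
        group_lr = cnn_lr if "cnn" in name else lr
        groups.append({"params": [param], "lr": group_lr, "weight_decay": wd})
    return groups


def warmup_cosine_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    # Linear ramp from 0 to 1 during warmup, then cosine decay from 1 to 0.
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


# Hypothetical parameter names; strings stand in for tensors.
names = [
    "visual.cnn.conv1.weight",
    "textual.embedding.layer_norm.weight",
    "textual.output.bias",
    "textual.output.weight",
]
groups = build_param_groups(
    [(n, n) for n in names],
    lr=1e-3,
    cnn_lr=2e-1,
    weight_decay=1e-4,
    no_decay_pattern=r".*textual.*(norm|bias)",
)
for g in groups:
    print(g["params"][0], g["lr"], g["weight_decay"])

print([round(warmup_cosine_multiplier(s, 100, 1100), 3) for s in (0, 50, 100, 600, 1100)])
```

One group per parameter keeps the rule simple; grouping parameters by their `(lr, weight_decay)` pair would give fewer groups with the same optimization behavior.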

================================================
FILE: virtex/model_zoo/__init__.py
================================================
from .model_zoo import get

__all__ = ["get"]


================================================
FILE: virtex/model_zoo/model_zoo.py
================================================
r"""
A utility module to easily load common VirTex models (optionally with pretrained
weights) using a single line of code.

Get our best-performing VirTex model (with pretrained weights) as:

>>> import virtex.model_zoo as mz
>>> model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True)

Any config available in the ``configs/`` directory under the project root can
be specified here, although this command need not be executed from the project
root. For more details on available models, refer to :doc:`usage/model_zoo`.

Part of this code is adapted from Detectron2's model zoo, which was originally
implemented by the developers of this codebase, with reviews and further
changes by Detectron2 developers.
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import pkg_resources

from fvcore.common.download import download
import torch

from virtex.config import Config
from virtex.factories import PretrainingModelFactory
from virtex.utils.checkpointing import CheckpointManager


class _ModelZooUrls:
    r"""Mapping from config names to URL suffixes of pretrained weights."""

    URL_PREFIX = "https://www.dropbox.com/s"

    CONFIG_PATH_TO_DB_ID = {

        # Pretraining Task Ablations
        "task_ablations/bicaptioning_R_50_L1_H2048.yaml": "mbeeso8wyieq8wy",
        "task_ablations/captioning_R_50_L1_H2048.yaml": "r6zen9k43m5oo58",
        "task_ablations/token_classification_R_50.yaml": "o4p9lki505r0mef",
        "task_ablations/multilabel_classification_R_50.yaml": "hbspp3jv3u8h3bc",
        "task_ablations/masked_lm_R_50_L1_H2048.yaml": "ldzrk6vem4mg6bl",

        # Width Ablations
        "width_ablations/bicaptioning_R_50_L1_H512.yaml": "o9fr69jjqfn8a65",
        "width_ablations/bicaptioning_R_50_L1_H768.yaml": "1zxglqrrbfufv9d",
        "width_ablations/bicaptioning_R_50_L1_H1024.yaml": "pdat4tvhnqxel64",
        "width_ablations/bicaptioning_R_50_L1_H2048.yaml": "mbeeso8wyieq8wy",

        # Depth Ablations
        "depth_ablations/bicaptioning_R_50_L1_H1024.yaml": "pdat4tvhnqxel64",
        "depth_ablations/bicaptioning_R_50_L2_H1024.yaml": "ft1vtt4okirzjgo",
        "depth_ablations/bicaptioning_R_50_L3_H1024.yaml": "5ldo1rcsnrshmjr",
        "depth_ablations/bicaptioning_R_50_L4_H1024.yaml": "zgiit2wcluuq3xh",

        # Backbone Ablations
        "backbone_ablations/bicaptioning_R_50_L1_H1024.yaml": "pdat4tvhnqxel64",
        "backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml": "5o198ux709r6376",
        "backbone_ablations/bicaptioning_R_101_L1_H1024.yaml": "bb74jubt68cpn80",
    }


def get(config_path: str, pretrained: bool = False):
    r"""
    Get a model specified by its relative path under VirTex's
    ``configs/`` directory.

    Args:
        config_path: Name of config file relative to ``configs/`` directory
            under project root. (E.g. ``width_ablations/bicaptioning_R_50_L1_H2048.yaml``)
        pretrained: If ``True``, will initialize the model with the pretrained
            weights. If ``False``, the weights will be initialized randomly.
    """

    # Get the path to the config file shipped inside the package.
    _pkg_config_path = pkg_resources.resource_filename(
        "virtex.model_zoo", os.path.join("configs", config_path)
    )
    if not os.path.exists(_pkg_config_path):
        raise RuntimeError("{} not available in Model Zoo!".format(config_path))

    _C = Config(_pkg_config_path)
    model = PretrainingModelFactory.from_config(_C)

    if pretrained:
        # Get URL for the checkpoint for this config path.
        if config_path in _ModelZooUrls.CONFIG_PATH_TO_DB_ID:

            dropbox_id = _ModelZooUrls.CONFIG_PATH_TO_DB_ID[config_path]
            filename = os.path.basename(config_path).replace(".yaml", ".pth")

            checkpoint_url = f"{_ModelZooUrls.URL_PREFIX}/{dropbox_id}/{filename}?dl=1"
        else:
            raise RuntimeError("{} not available in Model Zoo!".format(config_path))

        # Download the pretrained model weights and save with a sensible name.
        # This will be downloaded only if it does not exist.
        checkpoint_path = download(
            checkpoint_url,
            dir=os.path.expanduser("~/.torch/virtex_cache"),
            filename=filename,
        )
        CheckpointManager(model=model).load(checkpoint_path)

    return model
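
`get()` derives both the download URL and the cached file name from the config path alone. The helper below is a hypothetical, stdlib-only restatement of that logic (it downloads nothing); the two table entries are copied from `_ModelZooUrls`, and the local path assumes fvcore's `download()` stores the file as `os.path.join(dir, filename)`.

```python
import os
from typing import Dict, Tuple

URL_PREFIX = "https://www.dropbox.com/s"

# Two entries copied from _ModelZooUrls.CONFIG_PATH_TO_DB_ID.
CONFIG_PATH_TO_DB_ID: Dict[str, str] = {
    "width_ablations/bicaptioning_R_50_L1_H2048.yaml": "mbeeso8wyieq8wy",
    "depth_ablations/bicaptioning_R_50_L2_H1024.yaml": "ft1vtt4okirzjgo",
}


def checkpoint_location(
    config_path: str, cache_dir: str = "~/.torch/virtex_cache"
) -> Tuple[str, str]:
    """Hypothetical helper: return (download URL, assumed local cache path)."""
    if config_path not in CONFIG_PATH_TO_DB_ID:
        raise RuntimeError(f"{config_path} not available in Model Zoo!")
    dropbox_id = CONFIG_PATH_TO_DB_ID[config_path]
    # The checkpoint is named after the config file: ".yaml" -> ".pth".
    filename = os.path.basename(config_path).replace(".yaml", ".pth")
    url = f"{URL_PREFIX}/{dropbox_id}/{filename}?dl=1"
    local_path = os.path.join(os.path.expanduser(cache_dir), filename)
    return url, local_path
```

Because the lookup happens before any download, an unknown `config_path` fails fast with a `RuntimeError`. Several configs also share one Dropbox ID (the R-50, L1, H1024 model is listed under width, depth and backbone ablations), so they resolve to the same weights and the same cached file name.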


================================================
FILE: virtex/models/__init__.py
================================================
from .captioning import (
    ForwardCaptioningModel,
    BidirectionalCaptioningModel,
    VirTexModel
)
from .masked_lm import MaskedLMModel
from .classification import (
    MultiLabelClassificationModel,
    TokenClassificationModel,
)


__all__ = [
    "VirTexModel",
    "BidirectionalCaptioningModel",
    "ForwardCaptioningModel",
    "MaskedLMModel",
    "MultiLabelClassificationModel",
    "TokenClassificationModel",
]
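
`CaptioningModel.__init__` in `virtex/models/captioning.py` builds the backward head with `copy.deepcopy` and then re-assigns `visual_projection`, `embedding` and `output` to the forward head's modules, so everything else (for example the transformer layers) is an independent copy. The toy classes below (`ToyLayer`, `ToyTextualHead`) are hypothetical stand-ins for `nn.Module`s and only illustrate the object-sharing pattern.

```python
import copy


class ToyLayer:
    """Hypothetical stand-in for a layer with trainable weights."""

    def __init__(self, weight: float):
        self.weight = weight


class ToyTextualHead:
    """Toy head with the attributes that VirTex shares across directions."""

    def __init__(self):
        self.visual_projection = ToyLayer(1.0)
        self.embedding = ToyLayer(2.0)
        self.transformer = ToyLayer(3.0)  # direction-specific layers
        self.output = ToyLayer(4.0)


forward_head = ToyTextualHead()

# Clone everything, then re-point the shared parts at the forward head's objects.
backward_head = copy.deepcopy(forward_head)
backward_head.visual_projection = forward_head.visual_projection
backward_head.embedding = forward_head.embedding
backward_head.output = forward_head.output

# An update to a shared layer is visible from both heads...
forward_head.embedding.weight = 20.0
# ...while the cloned transformer layers evolve independently.
forward_head.transformer.weight = 30.0
```

In PyTorch, assigning an `nn.Module` to an attribute registers it as a submodule, so the shared layers appear under both heads, and `Module.parameters()` yields each shared tensor only once by default.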


================================================
FILE: virtex/models/captioning.py
================================================
import copy
import functools
from typing import Any, Dict

import torch
from torch import nn

from virtex.data.tokenizers import SentencePieceBPETokenizer
from virtex.modules.textual_heads import TextualHead
from virtex.modules.visual_backbones import VisualBackbone


class CaptioningModel(nn.Module):
    r"""
    A model to perform image captioning (either in both forward and backward
    directions independently, or only in the forward direction). It is composed of a
    :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    During training, it maximizes the likelihood of the ground truth caption
    conditioned on image features. During inference, it predicts a caption for
    an input image through beam search decoding or nucleus sampling.

    Args:
        visual: A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
            computes visual features from an input image.
        textual: A :class:`~virtex.modules.textual_heads.TextualHead` which
            makes final predictions conditioned on visual features.
        caption_backward: Whether to *also* perform captioning in backward
            direction. Default is ``False`` -- only forward captioning is
            performed. When ``True``, a clone of textual head is created, which
            does not share weights with "forward" model except the visual
            projection and input/output embeddings.
        sos_index: The index of the start token (``[SOS]``) in vocabulary.
        eos_index: The index of the end token (``[EOS]``) in vocabulary.
        decoder: A :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
            or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`
            object for decoding captions during inference (unused during training).
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        caption_backward: bool = False,
        sos_index: int = 1,
        eos_index: int = 2,
        decoder: Any = None,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx
        self.caption_backward = caption_backward

        # Clone the textual module for backward direction if doing captioning
        # in both directions (separately).
        if self.caption_backward:
            self.backward_textual = copy.deepcopy(self.textual)

            # Share weights for visual projection, and input/output embeddings.
            self.backward_textual.visual_projection = self.textual.visual_projection
            self.backward_textual.embedding = self.textual.embedding
            self.backward_textual.output = self.textual.output

        # These boundary indices are needed by the decoder (beam search or
        # nucleus sampling).
        self.sos_index = sos_index
        self.eos_index = eos_index
        self.decoder = decoder
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        r"""
        Given a batch of images and captions, compute the cross-entropy
        (negative log-likelihood) loss per caption token during training. During
        inference (given images only), predict a caption through either beam
        search decoding or nucleus sampling.

        Args:
            batch: A batch of images and (optionally) ground truth caption tokens.
                Possible set of keys: ``{"image_id", "image", "caption_tokens",
                "noitpac_tokens", "caption_lengths"}``.

        Returns:
            A dict with the following structure, containing loss for optimization,
            loss components to log directly to tensorboard, and optionally
            predictions.

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "captioning_forward": torch.Tensor,
                        "captioning_backward": torch.Tensor, (optional)
                    },
                    "predictions": torch.Tensor
                }
        """

        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        if "caption_tokens" in batch:
            caption_tokens = batch["caption_tokens"]
            caption_lengths = batch["caption_lengths"]

            # shape: (batch_size, max_caption_length, vocab_size)
            output_logits = self.textual(
                visual_features, caption_tokens, caption_lengths
            )
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                caption_tokens[:, 1:].contiguous().view(-1),
            )
            output_dict: Dict[str, Any] = {
                "loss": loss,
                # Single scalar per batch for logging in training script.
                "loss_components": {"captioning_forward": loss.clone().detach()},
            }
            # Do captioning in backward direction if specified.
            if self.caption_backward:
                backward_caption_tokens = batch["noitpac_tokens"]

                backward_output_logits = self.backward_textual(
                    visual_features, backward_caption_tokens, caption_lengths
                )
                backward_loss = self.loss(
                    backward_output_logits[:, :-1]
                    .contiguous()
                    .view(-1, self.textual.vocab_size),
                    backward_caption_tokens[:, 1:].contiguous().view(-1),
                )
                output_dict["loss"] += backward_loss

                # Single scalar per batch for logging in training script.
                output_dict["loss_components"].update(
                    captioning_backward=backward_loss.clone().detach()
                )

            if not self.training:
                # During validation (while pretraining), get best prediction
                # at every timestep.
                output_dict["predictions"] = torch.argmax(output_logits, dim=-1)
        else:
            if self.decoder is None:
                raise ValueError("Decoder for predicting captions is missing!")

            # During inference, decode captions with the forward model
            # (beam search or nucleus sampling, depending on the decoder).
            # Predictions from forward transformer will be shifted right by
            # one timestep.
            start_predictions = visual_features.new_full(
                (batch_size,), self.sos_index
            ).long()
            # Add image features as a default argument to match callable
            # signature accepted by beam search class (partial captions only).
            decoding_step = functools.partial(self.decoding_step, visual_features)

            predicted_caption, _ = self.decoder.search(
                start_predictions,
SYMBOL INDEX (215 symbols across 34 files)

FILE: docs/conf.py
  function linkcode_resolve (line 123) | def linkcode_resolve(domain, info):

FILE: hubconf.py
  function resnet50 (line 10) | def resnet50(pretrained: bool = False, **kwargs):

FILE: scripts/build_vocabulary.py
  function _read_captions (line 41) | def _read_captions(annotations_path: str) -> List[str]:

FILE: scripts/clf_linear.py
  function main (line 70) | def main(_A: argparse.Namespace):

FILE: scripts/clf_voc07.py
  function train_test_single_svm (line 56) | def train_test_single_svm(args):
  function main (line 108) | def main(_A: argparse.Namespace):

FILE: scripts/eval_captioning.py
  function main (line 44) | def main(_A: argparse.Namespace):

FILE: scripts/eval_detectron2.py
  class Res5ROIHeadsExtraNorm (line 82) | class Res5ROIHeadsExtraNorm(Res5ROIHeads):
    method _build_res5_block (line 88) | def _build_res5_block(self, cfg):
  function build_detectron2_config (line 95) | def build_detectron2_config(_C: Config, _A: argparse.Namespace):
  class DownstreamTrainer (line 119) | class DownstreamTrainer(DefaultTrainer):
    method __init__ (line 131) | def __init__(self, cfg, weights):
    method build_evaluator (line 153) | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
    method test (line 165) | def test(self, cfg=None, model=None, evaluators=None):
  function main (line 177) | def main(_A: argparse.Namespace):

FILE: scripts/pretrain_virtex.py
  function main (line 44) | def main(_A: argparse.Namespace):

FILE: setup.py
  function get_model_zoo_configs (line 9) | def get_model_zoo_configs() -> List[str]:

FILE: virtex/config.py
  class Config (line 6) | class Config:
    method __init__ (line 36) | def __init__(
    method dump (line 221) | def dump(self, file_path: str):
    method __getattr__ (line 229) | def __getattr__(self, attr: str):
    method __str__ (line 232) | def __str__(self):
    method __repr__ (line 235) | def __repr__(self):

FILE: virtex/data/datasets/captioning.py
  class CaptioningDataset (line 13) | class CaptioningDataset(Dataset):
    method __init__ (line 30) | def __init__(
    method __len__ (line 48) | def __len__(self):
    method __getitem__ (line 51) | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
    method collate_fn (line 79) | def collate_fn(

FILE: virtex/data/datasets/classification.py
  class TokenClassificationDataset (line 19) | class TokenClassificationDataset(Dataset):
    method __init__ (line 35) | def __init__(
    method __len__ (line 52) | def __len__(self):
    method __getitem__ (line 55) | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
    method collate_fn (line 81) | def collate_fn(
  class MultiLabelClassificationDataset (line 97) | class MultiLabelClassificationDataset(Dataset):
    method __init__ (line 110) | def __init__(
    method __len__ (line 154) | def __len__(self):
    method __getitem__ (line 157) | def __getitem__(self, idx: int):
    method collate_fn (line 176) | def collate_fn(

FILE: virtex/data/datasets/coco_captions.py
  class CocoCaptionsDataset (line 11) | class CocoCaptionsDataset(Dataset):
    method __init__ (line 22) | def __init__(self, data_root: str, split: str):
    method __len__ (line 53) | def __len__(self):
    method __getitem__ (line 56) | def __getitem__(self, idx: int):

FILE: virtex/data/datasets/downstream.py
  class ImageNetDataset (line 16) | class ImageNetDataset(ImageNet):
    method __init__ (line 29) | def __init__(
    method __getitem__ (line 38) | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
    method collate_fn (line 50) | def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch...
  class INaturalist2018Dataset (line 57) | class INaturalist2018Dataset(Dataset):
    method __init__ (line 69) | def __init__(
    method __len__ (line 92) | def __len__(self):
    method __getitem__ (line 95) | def __getitem__(self, idx: int):
    method collate_fn (line 111) | def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch...
  class VOC07ClassificationDataset (line 118) | class VOC07ClassificationDataset(Dataset):
    method __init__ (line 131) | def __init__(
    method __len__ (line 177) | def __len__(self):
    method __getitem__ (line 180) | def __getitem__(self, idx: int):
    method collate_fn (line 195) | def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch...
  class ImageDirectoryDataset (line 202) | class ImageDirectoryDataset(Dataset):
    method __init__ (line 214) | def __init__(
    method __len__ (line 220) | def __len__(self):
    method __getitem__ (line 223) | def __getitem__(self, idx: int):

FILE: virtex/data/datasets/masked_lm.py
  class MaskedLmDataset (line 15) | class MaskedLmDataset(Dataset):
    method __init__ (line 16) | def __init__(
    method __len__ (line 43) | def __len__(self):
    method __getitem__ (line 46) | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
    method collate_fn (line 101) | def collate_fn(
    method _random_token_index (line 124) | def _random_token_index(self) -> int:

FILE: virtex/data/tokenizers.py
  class SentencePieceBPETokenizer (line 6) | class SentencePieceBPETokenizer:
    method __init__ (line 16) | def __init__(self, model_path: str):
    method __getstate__ (line 23) | def __getstate__(self):
    method __setstate__ (line 32) | def __setstate__(self, state_dict: Dict[str, Any]):
    method get_vocab_size (line 38) | def get_vocab_size(self) -> int:
    method token_to_id (line 42) | def token_to_id(self, token: str) -> int:
    method id_to_token (line 48) | def id_to_token(self, token_id: int) -> str:
    method encode (line 52) | def encode(self, text: str) -> List[int]:
    method decode (line 56) | def decode(self, token_ids: List[int]) -> str:

FILE: virtex/data/transforms.py
  class HorizontalFlip (line 5) | class HorizontalFlip(alb.BasicTransform):
    method targets (line 25) | def targets(self):
    method apply (line 28) | def apply(self, img, **params):
    method apply_to_caption (line 31) | def apply_to_caption(self, caption, **params):
  class RandomResizedSquareCrop (line 40) | class RandomResizedSquareCrop(alb.RandomResizedCrop):
    method __init__ (line 49) | def __init__(self, size: int, *args, **kwargs):
  class CenterSquareCrop (line 53) | class CenterSquareCrop(alb.CenterCrop):
    method __init__ (line 62) | def __init__(self, size: int, *args, **kwargs):
  class SquareResize (line 66) | class SquareResize(alb.Resize):
    method __init__ (line 75) | def __init__(self, size: int, *args, **kwargs):

FILE: virtex/factories.py
  class Factory (line 40) | class Factory:
    method __init__ (line 60) | def __init__(self):
    method create (line 68) | def create(cls, name: str, *args, **kwargs) -> Any:
    method from_config (line 76) | def from_config(cls, config: Config) -> Any:
  class TokenizerFactory (line 81) | class TokenizerFactory(Factory):
    method from_config (line 94) | def from_config(cls, config: Config) -> SentencePieceBPETokenizer:
  class ImageTransformsFactory (line 111) | class ImageTransformsFactory(Factory):
    method create (line 158) | def create(cls, name: str, *args, **kwargs) -> Any:
    method from_config (line 171) | def from_config(cls, config: Config):
  class PretrainingDatasetFactory (line 176) | class PretrainingDatasetFactory(Factory):
    method from_config (line 200) | def from_config(cls, config: Config, split: str = "train"):
  class DownstreamDatasetFactory (line 249) | class DownstreamDatasetFactory(Factory):
    method from_config (line 264) | def from_config(cls, config: Config, split: str = "train"):
  class VisualBackboneFactory (line 306) | class VisualBackboneFactory(Factory):
    method from_config (line 322) | def from_config(cls, config: Config) -> visual_backbones.VisualBackbone:
  class TextualHeadFactory (line 344) | class TextualHeadFactory(Factory):
    method from_config (line 369) | def from_config(cls, config: Config) -> nn.Module:
  class PretrainingModelFactory (line 410) | class PretrainingModelFactory(Factory):
    method from_config (line 429) | def from_config(cls, config: Config) -> nn.Module:
  class CaptionDecoderFactory (line 469) | class CaptionDecoderFactory(Factory):
    method from_config (line 482) | def from_config(cls, config: Config) -> nn.Module:
  class OptimizerFactory (line 503) | class OptimizerFactory(Factory):
    method from_config (line 509) | def from_config(
  class LRSchedulerFactory (line 548) | class LRSchedulerFactory(Factory):
    method from_config (line 564) | def from_config(

FILE: virtex/model_zoo/model_zoo.py
  class _ModelZooUrls (line 30) | class _ModelZooUrls:
  function get (line 63) | def get(config_path: str, pretrained: bool = False):

FILE: virtex/models/captioning.py
  class CaptioningModel (line 13) | class CaptioningModel(nn.Module):
    method __init__ (line 40) | def __init__(
    method forward (line 71) | def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    method decoding_step (line 165) | def decoding_step(
    method log_predictions (line 215) | def log_predictions(
  class ForwardCaptioningModel (line 234) | class ForwardCaptioningModel(CaptioningModel):
    method __init__ (line 240) | def __init__(
  class BidirectionalCaptioningModel (line 258) | class BidirectionalCaptioningModel(CaptioningModel):
    method __init__ (line 264) | def __init__(
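
`CaptioningModel.forward` trains with teacher forcing: logits at position t are scored against the token at t + 1 (`output_logits[:, :-1]` vs. `caption_tokens[:, 1:]`), and `nn.CrossEntropyLoss(ignore_index=self.padding_idx)` skips padded targets before averaging. The stdlib-only sketch below reproduces that computation on nested lists; the padding index `PAD = 0` and the toy shapes are assumptions for illustration, not VirTex's tensors.

```python
import math
from typing import List

PAD = 0  # assumed padding index; VirTex reads it from textual.padding_idx


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max logit for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def shifted_caption_loss(
    logits: List[List[List[float]]],  # (batch, length, vocab)
    tokens: List[List[int]],  # (batch, length)
    padding_idx: int = PAD,
) -> float:
    """Mean next-token cross entropy, skipping targets equal to padding_idx."""
    total, count = 0.0, 0
    for seq_logits, seq_tokens in zip(logits, tokens):
        # Logits at step t are scored against the token at step t + 1,
        # i.e. logits[:, :-1] vs. tokens[:, 1:].
        for step_logits, target in zip(seq_logits[:-1], seq_tokens[1:]):
            if target == padding_idx:
                continue
            total -= log_softmax(step_logits)[target]
            count += 1
    # No valid targets at all: return NaN instead of dividing by zero.
    return total / count if count else float("nan")
```

For bicaptioning, `forward()` computes the same quantity on `noitpac_tokens` with the backward head and adds it to the forward loss.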

FILE: virtex/models/classification.py
  class ClassificationModel (line 12) | class ClassificationModel(nn.Module):
    method __init__ (line 35) | def __init__(
    method forward (line 43) | def forward(self, batch: Dict[str, torch.Tensor]):
  class TokenClassificationModel (line 111) | class TokenClassificationModel(ClassificationModel):
    method log_predictions (line 120) | def log_predictions(
  class MultiLabelClassificationModel (line 143) | class MultiLabelClassificationModel(ClassificationModel):
    method log_predictions (line 152) | def log_predictions(

FILE: virtex/models/masked_lm.py
  class MaskedLMModel (line 11) | class MaskedLMModel(nn.Module):
    method __init__ (line 28) | def __init__(self, visual: VisualBackbone, textual: TextualHead):
    method forward (line 35) | def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    method log_predictions (line 88) | def log_predictions(

FILE: virtex/modules/embedding.py
  class WordAndPositionalEmbedding (line 7) | class WordAndPositionalEmbedding(nn.Module):
    method __init__ (line 24) | def __init__(
    method forward (line 46) | def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    method _create_position_indices (line 77) | def _create_position_indices(self, tokens: torch.Tensor):

FILE: virtex/modules/textual_heads.py
  class TextualHead (line 15) | class TextualHead(nn.Module):
    method __init__ (line 29) | def __init__(self, visual_feature_size: int, vocab_size: int, hidden_s...
    method textual_feature_size (line 36) | def textual_feature_size(self):
  class LinearTextualHead (line 46) | class LinearTextualHead(TextualHead):
    method __init__ (line 57) | def __init__(self, visual_feature_size: int, vocab_size: int, **kwargs):
    method forward (line 63) | def forward(
  class TransformerDecoderTextualHead (line 98) | class TransformerDecoderTextualHead(TextualHead):
    method __init__ (line 146) | def __init__(
    method _init_weights (line 203) | def _init_weights(module):
    method forward (line 216) | def forward(
    method make_future_mask (line 282) | def make_future_mask(

FILE: virtex/modules/visual_backbones.py
  class VisualBackbone (line 8) | class VisualBackbone(nn.Module):
    method __init__ (line 15) | def __init__(self, visual_feature_size: int):
  class TorchvisionVisualBackbone (line 20) | class TorchvisionVisualBackbone(VisualBackbone):
    method __init__ (line 34) | def __init__(
    method forward (line 55) | def forward(self, image: torch.Tensor) -> torch.Tensor:
    method detectron2_backbone_state_dict (line 76) | def detectron2_backbone_state_dict(self) -> Dict[str, Any]:

FILE: virtex/optim/lookahead.py
  class Lookahead (line 25) | class Lookahead(Optimizer):
    method __init__ (line 36) | def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.8):
    method __getstate__ (line 52) | def __getstate__(self):
    method param_groups (line 62) | def param_groups(self):
    method zero_grad (line 65) | def zero_grad(self):
    method state_dict (line 69) | def state_dict(self):
    method load_state_dict (line 72) | def load_state_dict(self, state_dict: Dict[str, Any]):
    method step (line 82) | def step(self, closure: Callable = None):
    method load_slow_weights (line 104) | def load_slow_weights(self):
    method restore_fast_weights (line 120) | def restore_fast_weights(self):

FILE: virtex/optim/lr_scheduler.py
  class LinearWarmupNoDecayLR (line 9) | class LinearWarmupNoDecayLR(LambdaLR):
    method __init__ (line 23) | def __init__(
    method _lr_multiplier (line 38) | def _lr_multiplier(self, step: int) -> float:
  class LinearWarmupMultiStepLR (line 43) | class LinearWarmupMultiStepLR(LambdaLR):
    method __init__ (line 64) | def __init__(
    method _lr_multiplier (line 89) | def _lr_multiplier(self, step: int) -> float:
  class LinearWarmupLinearDecayLR (line 101) | class LinearWarmupLinearDecayLR(LambdaLR):
    method __init__ (line 115) | def __init__(
    method _lr_multiplier (line 130) | def _lr_multiplier(self, step: int) -> float:
  class LinearWarmupCosineAnnealingLR (line 141) | class LinearWarmupCosineAnnealingLR(LambdaLR):
    method __init__ (line 159) | def __init__(
    method _lr_multiplier (line 174) | def _lr_multiplier(self, step: int) -> float:

FILE: virtex/utils/beam_search.py
  class AutoRegressiveBeamSearch (line 25) | class AutoRegressiveBeamSearch:
    method __init__ (line 40) | def __init__(
    method search (line 52) | def search(
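
At inference, `CaptioningModel.forward` binds the visual features into `decoding_step` with `functools.partial`, so the decoder only ever calls `step(partial_captions)`. The sketch below shows that pattern with a greedy decoder (argmax at every step), which is deliberately simpler than `AutoRegressiveBeamSearch`. The toy step function, the `max_steps` cap, and the unbatched lists are assumptions; the real `decoding_step` works on batched tensors.

```python
import functools
from typing import Callable, List

SOS, EOS = 1, 2  # defaults of CaptioningModel(sos_index=1, eos_index=2)


def toy_decoding_step(visual_features: List[int], partial_caption: List[int]) -> List[float]:
    """Hypothetical step: scores over a 5-token vocabulary.

    It "reads" a caption stored in visual_features, one token per step.
    """
    position = len(partial_caption) - 1  # tokens generated so far, excluding SOS
    next_token = visual_features[position] if position < len(visual_features) else EOS
    scores = [0.0] * 5
    scores[next_token] = 1.0
    return scores


def greedy_decode(step: Callable[[List[int]], List[float]], max_steps: int = 20) -> List[int]:
    caption = [SOS]
    for _ in range(max_steps):
        scores = step(caption)
        # Greedy choice: the highest-scoring token.
        token = max(range(len(scores)), key=scores.__getitem__)
        if token == EOS:
            break
        caption.append(token)
    return caption[1:]  # drop SOS


features = [3, 4, 3]
# Same trick as forward(): bind the image features as the first argument.
step = functools.partial(toy_decoding_step, features)
```

Beam search differs by keeping the top-k partial captions at every step instead of a single one, but it consumes the same kind of bound `step` callable.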

FILE: virtex/utils/checkpointing.py
  class CheckpointManager (line 12) | class CheckpointManager:
    method __init__ (line 48) | def __init__(
    method step (line 68) | def step(self, iteration: int, metric: Optional[float] = None):
    method _state_dict (line 107) | def _state_dict(self):
    method remove_earliest_checkpoint (line 121) | def remove_earliest_checkpoint(self):
    method load (line 127) | def load(self, checkpoint_path: str):

FILE: virtex/utils/common.py
  function cycle (line 14) | def cycle(dataloader, device, start_iteration: int = 0):
  function common_setup (line 39) | def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pr...
  function common_parser (line 102) | def common_parser(description: str = "") -> argparse.ArgumentParser:

FILE: virtex/utils/distributed.py
  function launch (line 15) | def launch(
  function _job_worker (line 82) | def _job_worker(
  function synchronize (line 115) | def synchronize() -> None:
  function get_world_size (line 121) | def get_world_size() -> int:
  function get_rank (line 126) | def get_rank() -> int:
  function is_master_process (line 131) | def is_master_process() -> bool:
  function average_across_processes (line 140) | def average_across_processes(t: Union[torch.Tensor, Dict[str, torch.Tens...
  function gpu_mem_usage (line 163) | def gpu_mem_usage() -> int:

FILE: virtex/utils/metrics.py
  class TopkAccuracy (line 22) | class TopkAccuracy:
    method __init__ (line 35) | def __init__(self, k: int = 1):
    method reset (line 39) | def reset(self):
    method __call__ (line 43) | def __call__(self, predictions: torch.Tensor, ground_truth: torch.Tens...
    method get_result (line 70) | def get_result(self):
  class CocoCaptionsEvaluator (line 75) | class CocoCaptionsEvaluator:
    method __init__ (line 85) | def __init__(self, gt_annotations_path: str):
    method evaluate (line 95) | def evaluate(self, preds: List[Dict[str, Any]]) -> Dict[str, float]:
  function tokenize (line 125) | def tokenize(image_id_to_captions: Dict[int, List[str]]) -> Dict[int, Li...
  function cider (line 177) | def cider(
  function spice (line 267) | def spice(

FILE: virtex/utils/nucleus_sampling.py
  class AutoRegressiveNucleusSampling (line 25) | class AutoRegressiveNucleusSampling:
    method __init__ (line 36) | def __init__(
    method search (line 47) | def search(
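
`CaptioningModel` also accepts an `AutoRegressiveNucleusSampling` decoder. Nucleus (top-p) sampling keeps the smallest set of most probable tokens whose cumulative probability reaches `p`, renormalizes their probabilities, and samples from that set. The stdlib sketch below illustrates one step of the general technique; it is not VirTex's implementation, and the helper names and the seeded `random.Random` are assumptions.

```python
import random
from typing import Dict, List


def nucleus_filter(probs: List[float], top_p: float) -> Dict[int, float]:
    """Return {token_index: renormalized probability} for the top-p nucleus."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus, cumulative = [], 0.0
    for index in ranked:
        nucleus.append(index)
        cumulative += probs[index]
        # Stop once the kept tokens cover at least top_p of the mass.
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in nucleus)
    return {i: probs[i] / mass for i in nucleus}


def sample_next_token(probs: List[float], top_p: float, rng: random.Random) -> int:
    nucleus = nucleus_filter(probs, top_p)
    tokens = list(nucleus)
    weights = [nucleus[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]
```

With `probs = [0.5, 0.25, 0.125, 0.125]` and `top_p = 0.7`, only tokens 0 and 1 survive (cumulative mass 0.75), so the low-probability tail can never be sampled.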

FILE: virtex/utils/timer.py
  class Timer (line 5) | class Timer:
    method __init__ (line 17) | def __init__(
    method tic (line 31) | def tic(self) -> None:
    method toc (line 35) | def toc(self) -> None:
    method stats (line 42) | def stats(self) -> str:
    method eta_hhmm (line 50) | def eta_hhmm(self) -> str:
Condensed preview — 99 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (299K chars).
[
  {
    "path": ".gitignore",
    "chars": 683,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "CHANGELOG.md",
    "chars": 4420,
    "preview": "CHANGELOG\n=========\n\nThis CHANGELOG file records changes between different arXiv versions of our paper, and the version "
  },
  {
    "path": "LICENSE",
    "chars": 1057,
    "preview": "Copyright (c) 2020, Karan Desai.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this s"
  },
  {
    "path": "README.md",
    "chars": 3558,
    "preview": "VirTex: Learning Visual Representations from Textual Annotations\n======================================================="
  },
  {
    "path": "configs/_base_bicaptioning_R_50_L1_H1024.yaml",
    "chars": 1371,
    "preview": "# -----------------------------------------------------------------------------\n# Base config: VirTex pretraining for ou"
  },
  {
    "path": "configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml",
    "chars": 104,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  VISUAL:\n    NAME: \"torchvision::resnet101\"\n"
  },
  {
    "path": "configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml",
    "chars": 110,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  VISUAL:\n    NAME: \"torchvision::wide_resnet50_2\"\n"
  },
  {
    "path": "configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml",
    "chars": 51,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml",
    "chars": 51,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml",
    "chars": 120,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L2_H1024_A16_F4096\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml",
    "chars": 120,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L3_H1024_A16_F4096\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml",
    "chars": 120,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L4_H1024_A16_F4096\"\n"
  },
  {
    "path": "configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml",
    "chars": 1315,
    "preview": "# ----------------------------------------------------------------------------\n# Train a Faster R-CNN with ResNet-50 and"
  },
  {
    "path": "configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml",
    "chars": 2049,
    "preview": "# ----------------------------------------------------------------------------\n# Train a Mask R-CNN with ResNet-50 and F"
  },
  {
    "path": "configs/detectron2/coco_segm_default_init_2x.yaml",
    "chars": 685,
    "preview": "# -----------------------------------------------------------------------------\n# Train a Mask R-CNN R50-FPN backbone on"
  },
  {
    "path": "configs/detectron2/lvis_segm_default_init_2x.yaml",
    "chars": 900,
    "preview": "# -----------------------------------------------------------------------------\n# Train a Mask R-CNN R50-FPN backbone on"
  },
  {
    "path": "configs/detectron2/lvis_segm_imagenet_init_2x.yaml",
    "chars": 959,
    "preview": "# -----------------------------------------------------------------------------\n# Train a Mask R-CNN R50-FPN backbone on"
  },
  {
    "path": "configs/detectron2/voc_det_default_init_24k.yaml",
    "chars": 773,
    "preview": "# -----------------------------------------------------------------------------\n# Train a Faster R-CNN with R50-C4 backb"
  },
  {
    "path": "configs/downstream/imagenet_clf.yaml",
    "chars": 603,
    "preview": "RANDOM_SEED: 0\n# Don't need AMP to train a tiny linear layer.\nAMP: false\nCUDNN_BENCHMARK: true\nCUDNN_DETERMINISTIC: fals"
  },
  {
    "path": "configs/downstream/inaturalist_clf.yaml",
    "chars": 649,
    "preview": "RANDOM_SEED: 0\nAMP: true\nCUDNN_BENCHMARK: true\nCUDNN_DETERMINISTIC: false\n\nDATA:\n  ROOT: \"datasets/inaturalist\"\n  IMAGE_"
  },
  {
    "path": "configs/downstream/voc07_clf.yaml",
    "chars": 289,
    "preview": "RANDOM_SEED: 0\nDATA:\n  ROOT: datasets/VOC2007\n  IMAGE_TRANSFORM_TRAIN:\n    - smallest_resize\n    - center_crop\n    - nor"
  },
  {
    "path": "configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml",
    "chars": 120,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H2048_A32_F8192\"\n"
  },
  {
    "path": "configs/task_ablations/captioning_R_50_L1_H2048.yaml",
    "chars": 141,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  NAME: \"captioning\"\n  TEXTUAL:\n    NAME: \"transdec_postnorm:"
  },
  {
    "path": "configs/task_ablations/masked_lm_R_50_L1_H2048.yaml",
    "chars": 140,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  NAME: \"masked_lm\"\n  TEXTUAL:\n    NAME: \"transdec_postnorm::"
  },
  {
    "path": "configs/task_ablations/multilabel_classification_R_50.yaml",
    "chars": 174,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nDATA:\n  VOCAB_SIZE: 81\n\nMODEL:\n  NAME: \"multilabel_classification\"\n "
  },
  {
    "path": "configs/task_ablations/token_classification_R_50.yaml",
    "chars": 145,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  NAME: \"token_classification\"\n  TEXTUAL:\n    NAME: \"none\"\n\nO"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml",
    "chars": 51,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml",
    "chars": 120,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H2048_A32_F8192\"\n"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H512.yaml",
    "chars": 118,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H512_A8_F2048\"\n"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H768.yaml",
    "chars": 119,
    "preview": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H768_A12_F3072\"\n"
  },
  {
    "path": "docs/Makefile",
    "chars": 594,
    "preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHI"
  },
  {
    "path": "docs/_templates/layout.html",
    "chars": 624,
    "preview": "{% extends \"!layout.html\" %}\n\n{% block htmltitle %}\n\n    <!-- Global site tag (gtag.js) - Google Analytics -->\n    <scri"
  },
  {
    "path": "docs/conf.py",
    "chars": 4893,
    "preview": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common op"
  },
  {
    "path": "docs/index.rst",
    "chars": 3787,
    "preview": ".. raw:: html\n\n    <h1 style=\"text-align: center\">\n    VirTex: Learning Visual Representations from Textual Annotations\n"
  },
  {
    "path": "docs/virtex/config.rst",
    "chars": 228,
    "preview": "virtex.config\n=============\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.config\n\n\nConfig References\n----------------"
  },
  {
    "path": "docs/virtex/data.datasets.rst",
    "chars": 434,
    "preview": "virtex.data.datasets\n====================\n\n.. raw:: html\n\n    <hr>\n\nPretraining Datasets\n--------------------\n\n.. automo"
  },
  {
    "path": "docs/virtex/data.rst",
    "chars": 123,
    "preview": "virtex.data\n===========\n\n.. raw:: html\n\n    <hr>\n\n\n.. toctree::\n\n    data.datasets\n    data.tokenizers\n    data.transfor"
  },
  {
    "path": "docs/virtex/data.tokenizers.rst",
    "chars": 111,
    "preview": "virtex.data.tokenizers\n======================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.data.tokenizers\n"
  },
  {
    "path": "docs/virtex/data.transforms.rst",
    "chars": 111,
    "preview": "virtex.data.transforms\n======================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.data.transforms\n"
  },
  {
    "path": "docs/virtex/factories.rst",
    "chars": 1366,
    "preview": "virtex.factories\n================\n\n.. raw:: html\n\n    <hr>\n\n.. First only include the top-level module, and base class d"
  },
  {
    "path": "docs/virtex/model_zoo.rst",
    "chars": 103,
    "preview": "virtex.model_zoo\n================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.model_zoo.model_zoo\n"
  },
  {
    "path": "docs/virtex/models.rst",
    "chars": 344,
    "preview": "virtex.models\n=============\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.models.classification\n\n--------------------"
  },
  {
    "path": "docs/virtex/modules.embedding.rst",
    "chars": 117,
    "preview": "virtex.modules.embedding\n========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.modules.embedding\n"
  },
  {
    "path": "docs/virtex/modules.rst",
    "chars": 147,
    "preview": "virtex.modules\n==============\n\n.. raw:: html\n\n    <hr>\n\n.. toctree::\n\n    modules.embedding\n    modules.visual_backbones"
  },
  {
    "path": "docs/virtex/modules.textual_heads.rst",
    "chars": 129,
    "preview": "virtex.modules.textual_heads\n============================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.modules.textu"
  },
  {
    "path": "docs/virtex/modules.visual_backbones.rst",
    "chars": 138,
    "preview": "virtex.modules.visual_backbones\n===============================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.modules"
  },
  {
    "path": "docs/virtex/optim.lookahead.rst",
    "chars": 111,
    "preview": "virtex.optim.lookahead\n======================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.optim.lookahead\n"
  },
  {
    "path": "docs/virtex/optim.lr_scheduler.rst",
    "chars": 120,
    "preview": "virtex.optim.lr_scheduler\n=========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.optim.lr_scheduler\n"
  },
  {
    "path": "docs/virtex/optim.rst",
    "chars": 109,
    "preview": "virtex.optim\n============\n\n.. raw:: html\n\n    <hr>\n\n.. toctree::\n\n    optim.lookahead\n    optim.lr_scheduler\n"
  },
  {
    "path": "docs/virtex/usage/downstream.rst",
    "chars": 8344,
    "preview": "How to evaluate on downstream tasks?\n====================================\n\nIn our paper, we evaluate our pretrained VirT"
  },
  {
    "path": "docs/virtex/usage/model_zoo.rst",
    "chars": 10636,
    "preview": "VirTex Model Zoo\n================\n\nWe provide a collection of pretrained model weights and corresponding config\nnames in"
  },
  {
    "path": "docs/virtex/usage/pretrain.rst",
    "chars": 3433,
    "preview": "How to train your VirTex model?\n===============================\n\nWe provide training scripts for all type of VirTex mode"
  },
  {
    "path": "docs/virtex/usage/setup_dependencies.rst",
    "chars": 3326,
    "preview": "How to setup this codebase?\n===========================\n\n.. raw:: html\n\n    <hr>\n\nThis codebase requires Python 3.6+ or "
  },
  {
    "path": "docs/virtex/utils.beam_search.rst",
    "chars": 117,
    "preview": "virtex.utils.beam_search\n========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.beam_search\n"
  },
  {
    "path": "docs/virtex/utils.checkpointing.rst",
    "chars": 123,
    "preview": "virtex.utils.checkpointing\n==========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.checkpointi"
  },
  {
    "path": "docs/virtex/utils.common.rst",
    "chars": 102,
    "preview": "virtex.utils.common\n===================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.common\n"
  },
  {
    "path": "docs/virtex/utils.distributed.rst",
    "chars": 117,
    "preview": "virtex.utils.distributed\n========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.distributed\n"
  },
  {
    "path": "docs/virtex/utils.metrics.rst",
    "chars": 105,
    "preview": "virtex.utils.metrics\n====================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.metrics\n"
  },
  {
    "path": "docs/virtex/utils.rst",
    "chars": 185,
    "preview": "virtex.utils\n============\n\n.. raw:: html\n\n    <hr>\n\n.. toctree::\n\n    utils.common\n    utils.distributed\n    utils.timer"
  },
  {
    "path": "docs/virtex/utils.timer.rst",
    "chars": 99,
    "preview": "virtex.utils.timer\n==================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.timer\n"
  },
  {
    "path": "hubconf.py",
    "chars": 1209,
    "preview": "dependencies = [\"torch\"]\n\nimport torch\nimport torchvision\n\n\nR50_URL = \"https://www.dropbox.com/s/pxgjxcva7oypf12/backbon"
  },
  {
    "path": "requirements.txt",
    "chars": 185,
    "preview": "albumentations>=1.0\nCython>=0.25\nfuture==0.18.0\nloguru>=0.3\nlvis>=0.5\nnumpy>=1.17\nopencv-python>=4.2.0\nscikit-learn>=1.0"
  },
  {
    "path": "scripts/build_vocabulary.py",
    "chars": 3127,
    "preview": "import argparse\nimport json\nimport os\nimport tempfile\nimport unicodedata\nfrom typing import List\n\nimport sentencepiece a"
  },
  {
    "path": "scripts/clf_linear.py",
    "chars": 11584,
    "preview": "import argparse\nimport os\n\nfrom loguru import logger\nimport torch\nfrom torch import nn\nfrom torch.cuda import amp\nfrom t"
  },
  {
    "path": "scripts/clf_voc07.py",
    "chars": 9851,
    "preview": "import argparse\nimport multiprocessing as mp\nimport os\nfrom typing import Any, List\n\nimport numpy as np\nimport torch\nfro"
  },
  {
    "path": "scripts/eval_captioning.py",
    "chars": 4100,
    "preview": "import argparse\nimport json\nimport os\nfrom typing import Any, Dict, List\n\nfrom loguru import logger\nimport torch\nfrom to"
  },
  {
    "path": "scripts/eval_detectron2.py",
    "chars": 9172,
    "preview": "\"\"\"\nFinetune a pre-trained model on a downstream task, one of those available in\nDetectron2.\nSupported downstream:\n  - L"
  },
  {
    "path": "scripts/pretrain_virtex.py",
    "chars": 8743,
    "preview": "import argparse\nfrom collections import Counter\nfrom typing import Any\n\nfrom loguru import logger\nimport torch\nfrom torc"
  },
  {
    "path": "setup.py",
    "chars": 1572,
    "preview": "#!/usr/bin/env python\nimport glob\nimport os\nfrom setuptools import setup\nimport shutil\nfrom typing import List\n\n\ndef get"
  },
  {
    "path": "virtex/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "virtex/config.py",
    "chars": 10412,
    "preview": "from typing import Any, List, Optional\n\nfrom fvcore.common.config import CfgNode as CN\n\n\nclass Config:\n    r\"\"\"\n    This"
  },
  {
    "path": "virtex/data/__init__.py",
    "chars": 633,
    "preview": "from .datasets.captioning import CaptioningDataset\nfrom .datasets.classification import (\n    TokenClassificationDataset"
  },
  {
    "path": "virtex/data/datasets/captioning.py",
    "chars": 4064,
    "preview": "import random\nfrom typing import Callable, Dict, List\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import Data"
  },
  {
    "path": "virtex/data/datasets/classification.py",
    "chars": 7374,
    "preview": "from collections import defaultdict\nimport glob\nimport json\nimport os\nimport random\nfrom typing import Any, Callable, Di"
  },
  {
    "path": "virtex/data/datasets/coco_captions.py",
    "chars": 2300,
    "preview": "from collections import defaultdict\nimport json\nimport os\nimport unicodedata\nfrom typing import Dict, List\n\nimport cv2\nf"
  },
  {
    "path": "virtex/data/datasets/downstream.py",
    "chars": 8664,
    "preview": "from collections import defaultdict\nimport glob\nimport json\nimport os\nfrom typing import Callable, Dict, List, Tuple\n\nim"
  },
  {
    "path": "virtex/data/datasets/masked_lm.py",
    "chars": 5051,
    "preview": "import math\nimport random\nfrom typing import Callable, Dict, List\n\nimport albumentations as alb\nimport numpy as np\nimpor"
  },
  {
    "path": "virtex/data/tokenizers.py",
    "chars": 2051,
    "preview": "from typing import Any, Dict, List\n\nimport sentencepiece as sp\n\n\nclass SentencePieceBPETokenizer:\n    r\"\"\"\n    A tokeniz"
  },
  {
    "path": "virtex/data/transforms.py",
    "chars": 3262,
    "preview": "import albumentations as alb\nimport cv2\n\n\nclass HorizontalFlip(alb.BasicTransform):\n    r\"\"\"\n    Flip the image horizont"
  },
  {
    "path": "virtex/factories.py",
    "chars": 21780,
    "preview": "r\"\"\"\nThis module is a collection of *factories* for creating objects of datasets,\nmodels, optimizers and other useful co"
  },
  {
    "path": "virtex/model_zoo/__init__.py",
    "chars": 46,
    "preview": "from .model_zoo import get\n\n__all__ = [\"get\"]\n"
  },
  {
    "path": "virtex/model_zoo/model_zoo.py",
    "chars": 4425,
    "preview": "r\"\"\"\nA utility module to easily load common VirTex models (optionally with pretrained\nweights) using a single line of co"
  },
  {
    "path": "virtex/models/__init__.py",
    "chars": 431,
    "preview": "from .captioning import (\n    ForwardCaptioningModel,\n    BidirectionalCaptioningModel,\n    VirTexModel\n)\nfrom .masked_l"
  },
  {
    "path": "virtex/models/captioning.py",
    "chars": 11286,
    "preview": "import copy\nimport functools\nfrom typing import Any, Dict\n\nimport torch\nfrom torch import nn\n\nfrom virtex.data.tokenizer"
  },
  {
    "path": "virtex/models/classification.py",
    "chars": 6813,
    "preview": "from typing import Any, Dict, List\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom virtex."
  },
  {
    "path": "virtex/models/masked_lm.py",
    "chars": 4066,
    "preview": "from typing import Any, Dict\n\nimport torch\nfrom torch import nn\n\nfrom virtex.data.tokenizers import SentencePieceBPEToke"
  },
  {
    "path": "virtex/modules/embedding.py",
    "chars": 3428,
    "preview": "import functools\n\nimport torch\nfrom torch import nn\n\n\nclass WordAndPositionalEmbedding(nn.Module):\n    r\"\"\"\n    A :class"
  },
  {
    "path": "virtex/modules/textual_heads.py",
    "chars": 12233,
    "preview": "r\"\"\"\nA textual head accepts visual features from the visual backbone, and performs\ntask specific modeling (captioning, c"
  },
  {
    "path": "virtex/modules/visual_backbones.py",
    "chars": 4266,
    "preview": "from typing import Any, Dict\n\nimport torch\nfrom torch import nn\nimport torchvision\n\n\nclass VisualBackbone(nn.Module):\n  "
  },
  {
    "path": "virtex/optim/__init__.py",
    "chars": 58,
    "preview": "from .lookahead import Lookahead\n\n__all__ = [\"Lookahead\"]\n"
  },
  {
    "path": "virtex/optim/lookahead.py",
    "chars": 4496,
    "preview": "r\"\"\"\n`Lookahead Optimizer: k steps forward, 1 step back <https://arxiv.org/abs/1907.08610>`_.\n\nThis implementation is ad"
  },
  {
    "path": "virtex/optim/lr_scheduler.py",
    "chars": 6418,
    "preview": "import bisect\nimport math\nfrom typing import List\n\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler impor"
  },
  {
    "path": "virtex/utils/beam_search.py",
    "chars": 10686,
    "preview": "r\"\"\"\nThis Beam Search implementation is adapted with minor modifications from\n`AllenNLP <https://github.com/allenai/alle"
  },
  {
    "path": "virtex/utils/checkpointing.py",
    "chars": 6660,
    "preview": "import copy\nimport pathlib\nfrom typing import Any, Dict, List, Optional\n\nfrom loguru import logger\nimport torch\nfrom tor"
  },
  {
    "path": "virtex/utils/common.py",
    "chars": 5522,
    "preview": "import argparse\nimport os\nimport random\nimport sys\n\nfrom loguru import logger\nimport numpy as np\nimport torch\n\nfrom virt"
  },
  {
    "path": "virtex/utils/distributed.py",
    "chars": 5853,
    "preview": "r\"\"\"\nA collection of common utilities for distributed training. These are a bunch of\nwrappers over utilities from :mod:`"
  },
  {
    "path": "virtex/utils/metrics.py",
    "chars": 11204,
    "preview": "r\"\"\"\nThis module is a collection of metrics commonly used during pretraining and\ndownstream evaluation. Two main classes"
  },
  {
    "path": "virtex/utils/nucleus_sampling.py",
    "chars": 5064,
    "preview": "r\"\"\"\nNucleus Sampling was introduced in the paper\n`The Curious Case of Neural Text Degeneration <https://arxiv.org/abs/1"
  },
  {
    "path": "virtex/utils/timer.py",
    "chars": 1988,
    "preview": "import time\nfrom typing import Optional\n\n\nclass Timer:\n    r\"\"\"\n    A simple timer to record time per iteration and ETA "
  }
]
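
The textual head names in the config previews above follow a visible pattern, for example `transdec_postnorm::L2_H1024_A16_F4096`. Reading `L`, `H`, `A` and `F` as number of layers, hidden size, attention heads and feedforward size is an inference from the values (every listed name has `F = 4H` and `H / A = 64`), not something the previews state. The hypothetical helper below parses such names under that assumption. It is not part of virtex, whose factories may parse these names differently; the `"none"` case mirrors the token classification config.

```python
import re
from typing import NamedTuple, Optional


class TextualHeadSpec(NamedTuple):
    # Field meanings are inferred from the config names, not documented.
    architecture: str
    num_layers: int
    hidden_size: int
    attention_heads: int
    feedforward_size: int


# Hypothetical pattern for names like "transdec_postnorm::L1_H1024_A16_F4096".
_NAME_PATTERN = re.compile(
    r"^(?P<arch>\w+)::L(?P<layers>\d+)_H(?P<hidden>\d+)_A(?P<heads>\d+)_F(?P<ff>\d+)$"
)


def parse_textual_head_name(name: str) -> Optional[TextualHeadSpec]:
    # "none" is used by configs without a textual head (token classification).
    if name == "none":
        return None
    match = _NAME_PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unrecognized textual head name: {name!r}")
    return TextualHeadSpec(
        architecture=match.group("arch"),
        num_layers=int(match.group("layers")),
        hidden_size=int(match.group("hidden")),
        attention_heads=int(match.group("heads")),
        feedforward_size=int(match.group("ff")),
    )


spec = parse_textual_head_name("transdec_postnorm::L1_H2048_A32_F8192")
print(spec.num_layers, spec.hidden_size, spec.hidden_size // spec.attention_heads)
# 1 2048 64
```

Under this reading, the width ablations scale `H`, `A` and `F` together so that the per-head dimension stays at 64, while the depth ablations change only `L`.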
