[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# Code Editors\n.vscode\n.idea\n\n# Code linters\n.mypy_cache\n\n# Datasets and preprocessed files\ndata/\n!virtex/data\n\n# IPython Notebook\n.ipynb_checkpoints\n\n# virtualenv\nvenv/\nENV/\n\n# Temporary scripts to (smoke) test out bits and pieces of code.\nscripts/test_*\n\n# Data (symlinks) directory, model checkpoints, tensorboard logs etc.\ndatasets/\ncheckpoints/\nvirtex/utils/assets/\n!virtex/data/datasets/\nvirtex/model_zoo/configs\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "CHANGELOG\n=========\n\nThis CHANGELOG file records changes between different arXiv versions of our paper, and the version of this codebase which should be used to reproduce the results in the corresponding arXiv version. View changes between code versions on the [Releases page](https://github.com/kdexd/virtex/releases).\n\nArXiv v1 -> v2\n==============\n\n**Code version:** `v1.2`.\n\nFix image captioning results with a modified beam search implementation. _Rest of the downstream task results and pre-trained models are unchanged._\n\n\nArXiv v1 -> v2\n==============\n\n**Code version:** `v1.0` or `v1.1`.\n\n[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).\n\nWhile the core motivation and approach is the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performances across the board (within decimals). New models are available in [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html), however links to old models in `v0.9` will be active till June 30, 2021. We encourage you to use the new models!\n\nWe have updated the experiment config files for all changes described below.\n\nExperiment Changes\n------------------\n\n### New Feature:\n\nAdd a new pretraining task for BERT-style _Masked Language Modeling_. Pre-trained model released in Model Zoo.\n\n### Pre-training:\n\n- The only change during pre-training is that we do not apply weight decay to LayerNorm and biases in input embedding and transformer layers. 
We apply weight decay to the biases in the output linear layer (before softmax).\n\n- Other factors that could affect results:\n  - Use the official [albumentations.ColorJitter transform](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter) that mimics the torchvision ColorJitter transform. Earlier I implemented [my own ColorJitter](https://github.com/kdexd/virtex/blob/c19e7fc9b98e98af82286ed1537b6f588eaeac44/virtex/data/transforms.py#L156) because albumentations didn't have one.\n  - Use PyTorch native AMP (Automatic Mixed Precision) instead of NVIDIA Apex.\n\n### Downstream Evaluations:\n\n1. **PASCAL VOC 2007 Linear Classification:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-b4405dd4879a48ef1e5b1e2801035909584a5f1f32f63d5e793fb50dee077b97)\n   - Instead of training linear SVMs on 8192-dimensional average pooled features from ResNet-50 (7x7x2048 -> 2x2x2048), like [(Misra et al. 2019)](https://arxiv.org/abs/1905.01235), we directly train SVMs on 2048-dimensional global average pooled features, following recent works like [SwAV (Caron et al. 2020)](https://arxiv.org/abs/2006.09882).\n   - We change the pre-processing: resize the shortest edge to 256 pixels, and take a center crop of 224 pixels.\n   - These improve VOC mAP by 1-2 points everywhere, and make SVM training faster. Since we select the best checkpoint based on this metric, all results on other downstream tasks also change in `ArXiv v2` (but the trends remain the same).\n\n2. 
**ImageNet Linear Evaluation:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-d3dea1e7bf97d0cfca4b59a47c0a9bb81e78b8827654fe0258df9ce2c3f5f41c)\n   - Changed the random resized crop scale from (20-100%) to (8-100%) for consistency with evaluations in SSL works like MoCo and SwAV.\n   - Use cosine LR decay instead of step decay, following SwAV. Improves accuracy by up to 1%.\n\n3. **iNaturalist Fine-tuning:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-09096da78cfcde3a604ce22d80313f0800225d928cce5ef7334b89a382adfe4d)\n   - This evaluation is left unchanged across ArXiv versions, but we fixed a typo in the image pre-processing step, present in the publicly released config.\n\n4. **Detectron2 tasks (COCO and LVIS Instance Segmentation, VOC Detection):**\n   - Heavily simplified the script. Updated Detectron2 uses a more memory-efficient SyncBatchNorm and supports AMP.\n\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2020, Karan Desai.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and\nassociated documentation files (the \"Software\"), to deal in the Software without restriction,\nincluding without limitation the rights to use, copy, modify, merge, publish, distribute,\nsublicense, and/or sell copies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial\nportions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT\nNOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\nNONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES\nOR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN\nCONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "VirTex: Learning Visual Representations from Textual Annotations\n================================================================\n\n<h4>\nKaran Desai and Justin Johnson\n</br>\n<span style=\"font-size: 14pt; color: #555555\">\nUniversity of Michigan\n</span>\n</h4>\n<hr>\n\n**CVPR 2021** [arxiv.org/abs/2006.06666][1]\n\n**Model Zoo, Usage Instructions and API docs:** [kdexd.github.io/virtex](https://kdexd.github.io/virtex)\n\nVirTex is a pretraining approach which uses semantically dense captions to\nlearn visual representations. We train CNN + Transformers from scratch on\nCOCO Captions, and transfer the CNN to downstream vision tasks including\nimage classification, object detection, and instance segmentation.\nVirTex matches or outperforms models which use ImageNet for pretraining -- \nboth supervised or unsupervised -- despite using up to 10x fewer images.\n\n![virtex-model](docs/_static/system_figure.jpg)\n\n\nGet the pretrained ResNet-50 visual backbone from our best performing VirTex\nmodel in one line *without any installation*!\n\n```python\nimport torch\n\n# That's it, this one line only requires PyTorch.\nmodel = torch.hub.load(\"kdexd/virtex\", \"resnet50\", pretrained=True)\n```\n\n### Note (For returning users before January 2021):\n\nThe pretrained models in our model zoo have changed from [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0) onwards.\nThey are slightly better tuned than older models, and reproduce the results in our\nCVPR 2021 accepted paper ([arXiv v2](https://arxiv.org/abs/2006.06666v2)). \nSome training and evaluation hyperparams are changed since [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9).\nPlease refer [`CHANGELOG.md`](https://github.com/kdexd/virtex/blob/master/CHANGELOG.md)\n\n\nUsage Instructions\n------------------\n\n1. [How to setup this codebase?][2]  \n2. [VirTex Model Zoo][3]  \n3. [How to train your VirTex model?][4]  \n4. 
[How to evaluate on downstream tasks?][5]  \n\nFull documentation is available at [kdexd.github.io/virtex](https://kdexd.github.io/virtex).\n\n\nCitation\n--------\n\nIf you find this code useful, please consider citing:\n\n```text\n@inproceedings{desai2021virtex,\n    title={{VirTex: Learning Visual Representations from Textual Annotations}},\n    author={Karan Desai and Justin Johnson},\n    booktitle={CVPR},\n    year={2021}\n}\n```\n\nAcknowledgments\n---------------\n\nWe thank Harsh Agrawal, Mohamed El Banani, Richard  Higgins, Nilesh Kulkarni\nand Chris Rockwell for helpful discussions and feedback on the paper. We thank\nIshan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for\ndiscussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and\nYuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting\nthe Instance Segmentation pretraining task ablation; and Stefan Lee for\nsuggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs\nduring project development; and UMich ARC-TS team for support with GPU cluster\nmanagement. Finally, we thank all the Starbucks outlets in Ann Arbor for many\nhours of free WiFi. This work was partially supported by the Toyota Research\nInstitute (TRI). However, note that this article solely reflects the opinions\nand conclusions of its authors and not TRI or any other Toyota entity.\n\n\n[1]: https://arxiv.org/abs/2006.06666\n[2]: https://kdexd.github.io/virtex/virtex/usage/setup_dependencies.html\n[3]: https://kdexd.github.io/virtex/virtex/usage/model_zoo.html\n[4]: https://kdexd.github.io/virtex/virtex/usage/pretrain.html\n[5]: https://kdexd.github.io/virtex/virtex/usage/downstream.html\n"
  },
  {
    "path": "configs/_base_bicaptioning_R_50_L1_H1024.yaml",
    "content": "# -----------------------------------------------------------------------------\n# Base config: VirTex pretraining for our \"base\" bicaptioning model:\n# ResNet-50 + (L = 1, H = 1024) transformer trained for 500K iterations.\n# -----------------------------------------------------------------------------\nRANDOM_SEED: 0\nAMP: true\nCUDNN_BENCHMARK: true\nCUDNN_DETERMINISTIC: false\n\nDATA:\n  ROOT: \"datasets/coco\"\n  TOKENIZER_MODEL: \"datasets/vocab/coco_10k.model\"\n  VOCAB_SIZE: 10000\n  UNK_INDEX: 0\n  SOS_INDEX: 1\n  EOS_INDEX: 2\n  MASK_INDEX: 3\n\n  IMAGE_CROP_SIZE: 224\n  MAX_CAPTION_LENGTH: 30\n\n  IMAGE_TRANSFORM_TRAIN:\n    - \"random_resized_crop\"\n    - \"horizontal_flip\"\n    - \"color_jitter\"\n    - \"normalize\"\n\n  IMAGE_TRANSFORM_VAL:\n    - \"smallest_resize\"\n    - \"center_crop\"\n    - \"normalize\"\n\nMODEL:\n  NAME: \"virtex\"\n\n  VISUAL:\n    NAME: \"torchvision::resnet50\"\n    PRETRAINED: false\n    FROZEN: false\n\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H1024_A16_F4096\"\n    DROPOUT: 0.1\n\n  DECODER:\n    NAME: \"beam_search\"\n    BEAM_SIZE: 5\n\nOPTIM:\n  OPTIMIZER_NAME: \"sgd\"\n  SGD_MOMENTUM: 0.9\n  WEIGHT_DECAY: 0.0001\n\n  LOOKAHEAD:\n    USE: true\n    ALPHA: 0.5\n    STEPS: 5\n\n  BATCH_SIZE: 256\n  CNN_LR: 0.2\n  LR: 0.001\n  NUM_ITERATIONS: 500000\n\n  WARMUP_STEPS: 10000\n  LR_DECAY_NAME: \"cosine\"\n\n  NO_DECAY: \".*textual.(embedding|transformer).*(norm.*|bias)\"\n  CLIP_GRAD_NORM: 10.0\n\n"
  },
  {
    "path": "configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  VISUAL:\n    NAME: \"torchvision::resnet101\"\n"
  },
  {
    "path": "configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  VISUAL:\n    NAME: \"torchvision::wide_resnet50_2\"\n"
  },
  {
    "path": "configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L2_H1024_A16_F4096\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L3_H1024_A16_F4096\"\n"
  },
  {
    "path": "configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L4_H1024_A16_F4096\"\n"
  },
  {
    "path": "configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml",
    "content": "# ----------------------------------------------------------------------------\n# Train a Faster R-CNN with ResNet-50 and C4 backbone. This config follows\n# Detectron2 format; and is unrelated with our VirTex configs. Params here\n# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722).\n# ----------------------------------------------------------------------------\n\nINPUT:\n  # Input format will always be RGB, consistent with torchvision.\n  FORMAT: \"RGB\"\n  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)\n  MIN_SIZE_TEST: 800\n\nMODEL:\n  META_ARCHITECTURE: \"GeneralizedRCNN\"\n\n  # Train all layers end-to-end by default.\n  BACKBONE:\n    NAME: build_resnet_backbone\n    FREEZE_AT: 0\n\n  # Fine-tune with SyncBN.\n  # STRIDE_IN_1X1 is False for torchvision-like models.\n  RESNETS:\n    DEPTH: 50\n    NORM: SyncBN\n    STRIDE_IN_1X1: False\n\n  RPN:\n    PRE_NMS_TOPK_TEST: 6000\n    POST_NMS_TOPK_TEST: 1000\n\n  # ROI head with extra BN layer after res5 stage.\n  ROI_HEADS:\n    NAME: \"Res5ROIHeadsExtraNorm\"\n\n  # ImageNet color mean for torchvision-like models (RGB order).\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\n  PIXEL_STD: [58.395, 57.120, 57.375]\n\nSOLVER:\n  # This is for 8 GPUs, apply linear scaling for 4 GPUs.\n  IMS_PER_BATCH: 16\n  BASE_LR: 0.02\n\nTEST:\n  PRECISE_BN:\n    ENABLED: True\n\nVERSION: 2\n"
  },
  {
    "path": "configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml",
    "content": "# ----------------------------------------------------------------------------\n# Train a Mask R-CNN with ResNet-50 and FPN backbone. This config follows\n# Detectron2 format; and is unrelated with our VirTex configs. Params here\n# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722).\n# ----------------------------------------------------------------------------\n\nINPUT:\n  # Input format will always be RGB, consistent with torchvision.\n  FORMAT: \"RGB\"\n  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)\n  MIN_SIZE_TEST: 800\n\nMODEL:\n  META_ARCHITECTURE: \"GeneralizedRCNN\"\n\n  # Train all layers end-to-end by default.\n  BACKBONE:\n    NAME: \"build_resnet_fpn_backbone\"\n    FREEZE_AT: 0\n\n  # Fine-tune with SyncBN.\n  # STRIDE_IN_1X1 is False for torchvision-like models.\n  RESNETS:\n    DEPTH: 50\n    NORM: \"SyncBN\"\n    STRIDE_IN_1X1: False\n    OUT_FEATURES: [\"res2\", \"res3\", \"res4\", \"res5\"]\n\n  FPN:\n    IN_FEATURES: [\"res2\", \"res3\", \"res4\", \"res5\"]\n\n  ANCHOR_GENERATOR:\n    # One size for each in feature map\n    SIZES: [[32], [64], [128], [256], [512]]\n    # Three aspect ratios (same for all in feature maps)\n    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]\n\n  RPN:\n    IN_FEATURES: [\"p2\", \"p3\", \"p4\", \"p5\", \"p6\"]\n    PRE_NMS_TOPK_TRAIN: 2000\n    PRE_NMS_TOPK_TEST: 1000\n\n    POST_NMS_TOPK_TRAIN: 1000\n    POST_NMS_TOPK_TEST: 1000\n\n  ROI_HEADS:\n    NAME: \"StandardROIHeads\"\n    IN_FEATURES: [\"p2\", \"p3\", \"p4\", \"p5\"]\n\n  ROI_BOX_HEAD:\n    NAME: \"FastRCNNConvFCHead\"\n    NUM_FC: 2\n    POOLER_RESOLUTION: 7\n\n  ROI_MASK_HEAD:\n    NAME: \"MaskRCNNConvUpsampleHead\"\n    NUM_CONV: 4\n    POOLER_RESOLUTION: 14\n\n  # ImageNet color mean for torchvision-like models (RGB order).\n  # These are in [0-255] range as expected by Detectron2. 
The rest of our codebase\n  # uses the [0-1] range; both conventions are equivalent.\n  PIXEL_MEAN: [123.675, 116.280, 103.530]\n  PIXEL_STD: [58.395, 57.120, 57.375]\n\nSOLVER:\n  # This is for 8 GPUs, apply linear scaling for 4 GPUs.\n  IMS_PER_BATCH: 16\n  BASE_LR: 0.02\n\nTEST:\n  PRECISE_BN:\n    ENABLED: True\n\nVERSION: 2\n"
  },
  {
    "path": "configs/detectron2/coco_segm_default_init_2x.yaml",
    "content": "# -----------------------------------------------------------------------------\n# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation with any of\n# these weight init: random, imagenet (torchvision), virtex or MoCo.\n# -----------------------------------------------------------------------------\n_BASE_: \"_base_mask_rcnn_R_50_FPN.yaml\"\n\nDATASETS:\n  TRAIN: (\"coco_2017_train\",)\n  TEST: (\"coco_2017_val\",)\n\nMODEL:\n  MASK_ON: True\n  # FPN also has SyncBN, as opposed to no norm (usually).\n  FPN:\n    NORM: \"SyncBN\"\n  \n  # This will be ignored, weights will be loaded manually in the script.\n  WEIGHTS: \"\"\n  \nSOLVER:\n  STEPS: (120000, 160000)\n  MAX_ITER: 180000\n  \nVERSION: 2\n"
  },
  {
    "path": "configs/detectron2/lvis_segm_default_init_2x.yaml",
    "content": "# -----------------------------------------------------------------------------\n# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation with any of\n# these weight init: random, virtex or MoCo. (ImageNet init config is separate)\n# -----------------------------------------------------------------------------\n_BASE_: \"_base_mask_rcnn_R_50_FPN.yaml\"\n\nDATASETS:\n  TRAIN: (\"lvis_v1_train\",)\n  TEST: (\"lvis_v1_val\",)\n\nDATALOADER:\n  SAMPLER_TRAIN: \"RepeatFactorTrainingSampler\"\n  REPEAT_THRESHOLD: 0.001\n\nTEST:\n  DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300.\n\nMODEL:\n  MASK_ON: True\n  # FPN also has SyncBN, as opposed to no norm (usually).\n  FPN:\n    NORM: \"SyncBN\"\n\n  ROI_HEADS:\n    NUM_CLASSES: 1203\n    SCORE_THRESH_TEST: 0.0001\n\n  # This will be ignored, weights will be loaded manually in the script.\n  WEIGHTS: \"\"\n\nSOLVER:\n  STEPS: (120000, 160000)\n  MAX_ITER: 180000\n\nVERSION: 2\n\n"
  },
  {
    "path": "configs/detectron2/lvis_segm_imagenet_init_2x.yaml",
    "content": "# -----------------------------------------------------------------------------\n# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation\n# with weights initialized from supervised ImageNet pretraining (torchvision).\n# Key difference is that fine-tuning here happens with BN frozen.\n# -----------------------------------------------------------------------------\n_BASE_: \"_base_mask_rcnn_R_50_FPN.yaml\"\n\nDATASETS:\n  TRAIN: (\"lvis_v1_train\",)\n  TEST: (\"lvis_v1_val\",)\n\nDATALOADER:\n  SAMPLER_TRAIN: \"RepeatFactorTrainingSampler\"\n  REPEAT_THRESHOLD: 0.001\n\nTEST:\n  DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300.\n\nMODEL:\n  MASK_ON: True\n  RESNETS:\n    NORM: \"FrozenBN\"\n\n  # Do not tune with SyncBN for ImageNet init from LVIS.\n  ROI_HEADS:\n    NUM_CLASSES: 1203\n    SCORE_THRESH_TEST: 0.0001\n\n  # This will be ignored, weights will be loaded manually in the script.\n  WEIGHTS: \"\"\n\nSOLVER:\n  STEPS: (120000, 160000)\n  MAX_ITER: 180000\n\nVERSION: 2\n\n\n"
  },
  {
    "path": "configs/detectron2/voc_det_default_init_24k.yaml",
    "content": "# -----------------------------------------------------------------------------\n# Train a Faster R-CNN with R50-C4 backbone on VOC07+12 detection with any of\n# these weight init: random, imagenet (torchvision), virtex or MoCo.\n# -----------------------------------------------------------------------------\n_BASE_: \"_base_faster_rcnn_R_50_C4_BN.yaml\"\n\nDATASETS:\n  TRAIN: (\"voc_2007_trainval\", \"voc_2012_trainval\")\n  TEST: (\"voc_2007_test\",)\n\nINPUT:\n  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)\n  MIN_SIZE_TEST: 800\n\nMODEL:\n  MASK_ON: False\n  ROI_HEADS:\n    NUM_CLASSES: 20\n\n  # This will be ignored, weights will be loaded manually in the script.\n  WEIGHTS: \"\"\n\nSOLVER:\n  STEPS: (18000, 22000)\n  MAX_ITER: 24000\n  WARMUP_ITERS: 100\n\nVERSION: 2\n"
  },
  {
    "path": "configs/downstream/imagenet_clf.yaml",
    "content": "RANDOM_SEED: 0\n# Don't need AMP to train a tiny linear layer.\nAMP: false\nCUDNN_BENCHMARK: true\nCUDNN_DETERMINISTIC: false\n\nDATA:\n  ROOT: \"datasets/imagenet\"\n  IMAGE_TRANSFORM_TRAIN:\n    - \"random_resized_crop::{'scale': (0.08, 1.0)}\"\n    - \"horizontal_flip\"\n    - \"normalize\"\n  IMAGE_TRANSFORM_VAL:\n    - \"smallest_resize\"\n    - \"center_crop\"\n    - \"normalize\"\n\nMODEL:\n  VISUAL:\n    FROZEN: true\n\nOPTIM:\n  BATCH_SIZE: 256\n  SGD_MOMENTUM: 0.9\n  WEIGHT_DECAY: 0.0\n  NO_DECAY: \"none\"\n  LOOKAHEAD:\n    USE: false\n\n  LR: 0.3\n  WARMUP_STEPS: 0\n  LR_DECAY_NAME: \"cosine\"\n  NUM_ITERATIONS: 500500  # 100 epochs\n"
  },
  {
    "path": "configs/downstream/inaturalist_clf.yaml",
    "content": "RANDOM_SEED: 0\nAMP: true\nCUDNN_BENCHMARK: true\nCUDNN_DETERMINISTIC: false\n\nDATA:\n  ROOT: \"datasets/inaturalist\"\n  IMAGE_TRANSFORM_TRAIN:\n    - \"random_resized_crop::{'scale': (0.08, 1.0)}\"\n    - \"horizontal_flip\"\n    - \"normalize\"\n  IMAGE_TRANSFORM_VAL:\n    - \"smallest_resize\"\n    - \"center_crop\"\n    - \"normalize\"\n\nMODEL:\n  VISUAL:\n    FROZEN: false\n    \nOPTIM:\n  BATCH_SIZE: 256\n  SGD_MOMENTUM: 0.9\n  WEIGHT_DECAY: 0.0001\n  NO_DECAY: \"none\"\n  LOOKAHEAD:\n    USE: false\n\n  LR: 0.025\n  WARMUP_STEPS: 0\n  LR_DECAY_NAME: multistep\n  LR_GAMMA: 0.1\n  LR_STEPS:\n    - 119700  # 70 epochs\n    - 153900  # 90 epochs\n  NUM_ITERATIONS: 171000  # 100 epochs\n"
  },
  {
    "path": "configs/downstream/voc07_clf.yaml",
    "content": "RANDOM_SEED: 0\nDATA:\n  ROOT: datasets/VOC2007\n  IMAGE_TRANSFORM_TRAIN:\n    - smallest_resize\n    - center_crop\n    - normalize\n  IMAGE_TRANSFORM_VAL:\n    - smallest_resize\n    - center_crop\n    - normalize\n\nOPTIM:\n  # Only used for feature extraction, doesn't mean much.\n  BATCH_SIZE: 128\n"
  },
  {
    "path": "configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H2048_A32_F8192\"\n"
  },
  {
    "path": "configs/task_ablations/captioning_R_50_L1_H2048.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  NAME: \"captioning\"\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H2048_A32_F8192\"\n"
  },
  {
    "path": "configs/task_ablations/masked_lm_R_50_L1_H2048.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  NAME: \"masked_lm\"\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H2048_A32_F8192\"\n"
  },
  {
    "path": "configs/task_ablations/multilabel_classification_R_50.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nDATA:\n  VOCAB_SIZE: 81\n\nMODEL:\n  NAME: \"multilabel_classification\"\n  TEXTUAL:\n    NAME: \"none\"\n\nOPTIM:\n  NO_DECAY: \"none\"\n"
  },
  {
    "path": "configs/task_ablations/token_classification_R_50.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  NAME: \"token_classification\"\n  TEXTUAL:\n    NAME: \"none\"\n\nOPTIM:\n  NO_DECAY: \"none\"\n"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H2048_A32_F8192\"\n"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H512.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H512_A8_F2048\"\n"
  },
  {
    "path": "configs/width_ablations/bicaptioning_R_50_L1_H768.yaml",
    "content": "_BASE_: \"../_base_bicaptioning_R_50_L1_H1024.yaml\"\n\nMODEL:\n  TEXTUAL:\n    NAME: \"transdec_postnorm::L1_H768_A12_F3072\"\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = ../../virtex-sphinx\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/_templates/layout.html",
    "content": "{% extends \"!layout.html\" %}\n\n{% block htmltitle %}\n\n    <!-- Global site tag (gtag.js) - Google Analytics -->\n    <script async src=\"https://www.googletagmanager.com/gtag/js?id=UA-120523111-2\"></script>\n    <script>\n    window.dataLayer = window.dataLayer || [];\n    function gtag(){dataLayer.push(arguments);}\n    gtag('js', new Date());\n\n    gtag('config', 'UA-120523111-2');\n    </script>\n\n    <link href=\"https://fonts.googleapis.com/css?family=Inconsolata&display=swap\" rel=\"stylesheet\">\n    <link href=\"https://fonts.googleapis.com/css?family=Ubuntu+Mono&display=swap\" rel=\"stylesheet\">\n\n{{ super() }}\n{% endblock %}\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# http://www.sphinx-doc.org/en/master/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport inspect\nimport os\nimport sys\n\nsys.path.insert(0, os.path.abspath(\"../\"))\n\n\n# -- Project information -----------------------------------------------------\n\nproject = \"virtex\"\ncopyright = \"2021, Karan Desai and Justin Johnson\"\nauthor = \"Karan Desai\"\n\n# The full version, including alpha/beta/rc tags\nrelease = \"1.4\"\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.coverage\",\n    \"sphinx.ext.doctest\",\n    \"sphinx.ext.linkcode\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx.ext.autosummary\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.mathjax\",\n    \"sphinx_copybutton\",\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffixes as a list of strings:\n#\n# source_suffix = ['.rst', '.md']\nsource_suffix = \".rst\"\n\n# The master toctree document.\nmaster_doc = \"index\"\n\n# The version info for the project you're documenting, acts as replacement for\n# |version| and |release|, also used in various other places throughout the\n# built documents.\n#\n# This version is used underneath the title on the index page.\nversion = \"1.4\"\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = \"en\"\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# These patterns also affect html_static_path and html_extra_path.\nexclude_patterns = [\"_build\"]\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = \"sphinx\"\n\n# If true, `todo` and `todoList` produce output, else they produce nothing.\ntodo_include_todos = False\n\nnumpydoc_show_class_members = False\n\n\n# -- Options for HTML output ----------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  
See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"sphinx_rtd_theme\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n\n\n# -- Autodoc configuration ------------------------------------------------\n\nautodoc_default_options = {\n    \"members\": True,\n    \"member-order\": \"bysource\",\n    \"private-members\": True,\n    \"show-inheritance\": True,\n}\n\n\n# -- Intersphinx configuration --------------------------------------------\n\nintersphinx_mapping = {\n    \"torch\": (\"https://pytorch.org/docs/stable/\", None),\n    \"albumentations\": (\"https://albumentations.readthedocs.io/en/latest/\", None),\n}\n\n# -- Miscellaneous Extra Tweaks -------------------------------------------\n\n# make github links resolve\ndef linkcode_resolve(domain, info):\n    \"\"\"\n    Determine the URL corresponding to Python object\n    This code is from\n    https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L290\n    and https://github.com/Lasagne/Lasagne/pull/262\n    \"\"\"\n    if domain != \"py\":\n        return None\n\n    modname = info[\"module\"]\n    fullname = info[\"fullname\"]\n\n    submod = sys.modules.get(modname)\n    if submod is None:\n        return None\n\n    obj = submod\n    for part in fullname.split(\".\"):\n        try:\n            obj = getattr(obj, part)\n        except:  # noqa: E722\n            return None\n\n    try:\n        fn = inspect.getsourcefile(obj)\n    except:  # noqa: E722\n        fn = None\n    if not fn:\n        return None\n\n    try:\n        source, lineno = inspect.getsourcelines(obj)\n    except:  # noqa: E722\n        lineno = None\n\n    if lineno:\n        linespec = \"#L%d-L%d\" % (lineno, lineno + len(source) - 1)\n    else:\n        linespec = \"\"\n\n    
filename = info[\"module\"].replace(\".\", \"/\")\n    return f\"https://github.com/kdexd/virtex/blob/master/{filename}.py{linespec}\"\n"
  },
  {
    "path": "docs/index.rst",
    "content": ".. raw:: html\n\n    <h1 style=\"text-align: center\">\n    VirTex: Learning Visual Representations from Textual Annotations\n    </h1>\n    <h4 style=\"text-align: center\">\n    Karan Desai and Justin Johnson\n    </br>\n    <span style=\"font-size: 14pt; color: #555555\">\n    University of Michigan\n    </span>\n    </h4>\n    <hr>\n\n    <h4 style=\"text-align: center\">\n    Abstract\n    </h4>\n\n    <p style=\"text-align: justify\">\n    The de-facto approach to many vision tasks is to start from pretrained\n    visual representations, typically learned via supervised training on\n    ImageNet. Recent methods have explored unsupervised pretraining to scale to\n    vast quantities of unlabeled images. In contrast, we aim to learn\n    high-quality visual representations from fewer images. To this end we\n    revisit supervised pretraining, and seek data-efficient alternatives to\n    classification-based pretraining. We propose VirTex -- a pretraining\n    approach using semantically dense captions to learn visual representations.\n    We train convolutional networks from scratch on COCO Captions, and transfer\n    them to downstream recognition tasks including image classification, object\n    detection, and instance segmentation. On all tasks, VirTex yields features\n    that match or exceed those learned on ImageNet -- supervised or unsupervised\n    -- despite using up to ten times fewer images.\n    </p>\n\n**CVPR 2021. Paper available at:** `arxiv.org/abs/2006.06666 <https://arxiv.org/abs/2006.06666>`_.\n\n**Code available at:** `github.com/kdexd/virtex <https://github.com/kdexd/virtex>`_.\n\n.. image:: _static/system_figure.jpg\n\n\nGet the pretrained ResNet-50 visual backbone from our best performing VirTex\nmodel in one line *without any installation*!\n\n.. 
code-block:: python\n\n    import torch\n\n    # That's it, this one line only requires PyTorch.\n    model = torch.hub.load(\"kdexd/virtex\", \"resnet50\", pretrained=True)\n\n\nMore details in :doc:`virtex/usage/model_zoo`. Next, dive deeper into our\ncode with User Guide and API References!\n\n\nUser Guide\n----------\n\n.. toctree::\n    :maxdepth: 2\n\n    virtex/usage/setup_dependencies\n    virtex/usage/model_zoo\n    virtex/usage/pretrain\n    virtex/usage/downstream\n\n\nAPI Reference\n-------------\n\n.. toctree::\n    :maxdepth: 2\n\n    virtex/config\n    virtex/factories\n    virtex/data\n    virtex/models\n    virtex/modules\n    virtex/optim\n    virtex/utils\n    virtex/model_zoo\n\n\nCitation\n--------\n\nIf you find this code useful, please consider citing:\n\n.. code-block:: text\n\n    @inproceedings{desai2021virtex,\n        title={{VirTex: Learning Visual Representations from Textual Annotations}},\n        author={Karan Desai and Justin Johnson},\n        booktitle={CVPR},\n        year={2021}\n    }\n\n\nAcknowledgments\n---------------\n\nWe thank Harsh Agrawal, Mohamed El Banani, Richard  Higgins, Nilesh Kulkarni\nand Chris Rockwell for helpful discussions and feedback on the paper. We thank\nIshan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for\ndiscussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and\nYuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting\nthe Instance Segmentation pretraining task ablation; and Stefan Lee for\nsuggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs\nduring project development; and UMich ARC-TS team for support with GPU cluster\nmanagement. Finally, we thank all the Starbucks outlets in Ann Arbor for many\nhours of free WiFi. This work was partially supported by the Toyota Research\nInstitute (TRI). 
However, note that this article solely reflects the opinions\nand conclusions of its authors and not TRI or any other Toyota entity.\n\n\nIndices and Tables\n------------------\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/virtex/config.rst",
    "content": "virtex.config\n=============\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.config\n\n\nConfig References\n-----------------\n\n.. literalinclude:: ../../virtex/config.py\n  :language: python\n  :linenos:\n  :lines: 42-210\n  :dedent: 8\n"
  },
  {
    "path": "docs/virtex/data.datasets.rst",
    "content": "virtex.data.datasets\n====================\n\n.. raw:: html\n\n    <hr>\n\nPretraining Datasets\n--------------------\n\n.. automodule:: virtex.data.datasets.coco_captions\n\n.. automodule:: virtex.data.datasets.captioning\n\n.. automodule:: virtex.data.datasets.classification\n\n------------------------------------------------------------------------------\n\nDownstream Datasets\n-------------------\n\n.. automodule:: virtex.data.datasets.downstream\n"
  },
  {
    "path": "docs/virtex/data.rst",
    "content": "virtex.data\n===========\n\n.. raw:: html\n\n    <hr>\n\n\n.. toctree::\n\n    data.datasets\n    data.tokenizers\n    data.transforms\n"
  },
  {
    "path": "docs/virtex/data.tokenizers.rst",
    "content": "virtex.data.tokenizers\n======================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.data.tokenizers\n"
  },
  {
    "path": "docs/virtex/data.transforms.rst",
    "content": "virtex.data.transforms\n======================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.data.transforms\n"
  },
  {
    "path": "docs/virtex/factories.rst",
    "content": "virtex.factories\n================\n\n.. raw:: html\n\n    <hr>\n\n.. First only include the top-level module, and base class docstrings.\n\n.. automodule:: virtex.factories\n    :no-members:\n\n.. autoclass:: virtex.factories.Factory\n\n\n------------------------------------------------------------------------------\n\nDataloading-related Factories\n-----------------------------\n\n.. autoclass:: virtex.factories.TokenizerFactory\n    :members: from_config\n\n.. autoclass:: virtex.factories.ImageTransformsFactory\n    :members: from_config\n\n.. autoclass:: virtex.factories.PretrainingDatasetFactory\n    :members: from_config\n\n.. autoclass:: virtex.factories.DownstreamDatasetFactory\n    :members: from_config\n\n------------------------------------------------------------------------------\n\nModeling-related Factories\n--------------------------\n\n.. autoclass:: virtex.factories.VisualBackboneFactory\n    :members: from_config\n\n.. autoclass:: virtex.factories.TextualHeadFactory\n    :members: from_config\n\n.. autoclass:: virtex.factories.PretrainingModelFactory\n    :members: from_config\n\n------------------------------------------------------------------------------\n\nOptimization-related Factories\n------------------------------\n\n.. autoclass:: virtex.factories.OptimizerFactory\n    :members: from_config\n\n.. autoclass:: virtex.factories.LRSchedulerFactory\n    :members: from_config\n"
  },
  {
    "path": "docs/virtex/model_zoo.rst",
    "content": "virtex.model_zoo\n================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.model_zoo.model_zoo\n"
  },
  {
    "path": "docs/virtex/models.rst",
    "content": "virtex.models\n=============\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.models.classification\n\n-------------------------------------------------------------------------------\n\n.. automodule:: virtex.models.captioning\n\n-------------------------------------------------------------------------------\n\n.. automodule:: virtex.models.masked_lm\n"
  },
  {
    "path": "docs/virtex/modules.embedding.rst",
    "content": "virtex.modules.embedding\n========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.modules.embedding\n"
  },
  {
    "path": "docs/virtex/modules.rst",
    "content": "virtex.modules\n==============\n\n.. raw:: html\n\n    <hr>\n\n.. toctree::\n\n    modules.embedding\n    modules.visual_backbones\n    modules.textual_heads\n"
  },
  {
    "path": "docs/virtex/modules.textual_heads.rst",
    "content": "virtex.modules.textual_heads\n============================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.modules.textual_heads\n"
  },
  {
    "path": "docs/virtex/modules.visual_backbones.rst",
    "content": "virtex.modules.visual_backbones\n===============================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.modules.visual_backbones\n"
  },
  {
    "path": "docs/virtex/optim.lookahead.rst",
    "content": "virtex.optim.lookahead\n======================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.optim.lookahead\n"
  },
  {
    "path": "docs/virtex/optim.lr_scheduler.rst",
    "content": "virtex.optim.lr_scheduler\n=========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.optim.lr_scheduler\n"
  },
  {
    "path": "docs/virtex/optim.rst",
    "content": "virtex.optim\n============\n\n.. raw:: html\n\n    <hr>\n\n.. toctree::\n\n    optim.lookahead\n    optim.lr_scheduler\n"
  },
  {
    "path": "docs/virtex/usage/downstream.rst",
    "content": "How to evaluate on downstream tasks?\n====================================\n\nIn our paper, we evaluate our pretrained VirTex models on seven different\ndownstream tasks. Our codebase supports all of these evaluations. Throughout\nthis documentation, we consider a specific example of our VirTex pretrained\nmodel being evaluated for ensuring filepath uniformity in the following example\ncommand snippets. Paths can be trivially adjusted for any other VirTex model;\nevaluating the baselines (MoCo, ImageNet-supervised, Random Init) require\nadditional changes in commands, explained in the last sub-section.\n\nAs an example, consider a pretraining job for our best performing VirTex model\n(``width_ablations/bicaptioning_R_50_L1_H2048.yaml``). The serialization\ndirectory might look something like this:\n\n.. code-block:: text\n\n    /tmp/bicaptioning_R_50_L1_H2048\n        pretrain_config.yaml\n        log-rank0.txt    # stdout/stderr per GPU process\n        log-rank1.txt\n        ...\n        log-rank7.txt\n        checkpoint_2000.pth\n        checkpoint_4000.pth\n        ...\n        checkpoint_498000.pth\n        checkpoint_500000.pth    # serialized checkpoints\n        train_captioning_forward/\n            events.out.* ...    # tensorboard logs\n        ...\n\nWe evaluate all checkpoints on **PASCAL VOC 2007 Linear Classification**, and\nthen evaluate the best checkpoint (here, it was iteration 500000) on all other\ndownstream tasks.\n\n\nPASCAL VOC 2007 Linear Classification\n-------------------------------------\n\nEvaluate a single VirTex pretrained checkpoint on VOC 2007 ``trainval`` split:\n\n.. 
code-block:: shell\n\n    python scripts/clf_voc07.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --down-config configs/downstream/voc07_clf.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --weight-init virtex \\\n        --num-gpus-per-machine 1 \\\n        --cpu-workers 4 \\\n        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048\n\nTo evaluate the most recent 100 checkpoints in this directory, wrap this\ncommand in a loop:\n\n.. code-block:: shell\n\n    for ((iter = 300000; iter <= 500000; iter+=2000)); do\n        # add command with `checkpoint_$iter.pth`\n    done\n\nThis script writes metrics to tensorboard logs in the same pretraining\ndirectory, so all VOC07 mAP curves appear together with the pretraining loss\ncurves.\n\n-------------------------------------------------------------------------------\n\nImageNet Linear Classification\n------------------------------\n\nWe train a linear classifier on 2048-dimensional global average pooled features\nextracted from a frozen visual backbone. Evaluate a checkpoint (for example,\niteration 500000) on this task as:\n\n.. 
code-block:: shell\n\n    python scripts/clf_linear.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --down-config configs/downstream/imagenet_clf.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --weight-init virtex \\\n        --num-gpus-per-machine 8 \\\n        --cpu-workers 4 \\\n        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/imagenet_500000 \\\n        --checkpoint-every 5005  # 1 epoch of ImageNet\n\n-------------------------------------------------------------------------------\n\nInstance Segmentation (and Object Detection) on COCO\n----------------------------------------------------\n\nTrain a Mask R-CNN with FPN backbone for COCO Instance Segmentation (and Object\nDetection, because it also has a box head) by initializing the backbone from\nVirTex pretrained weights:\n\n.. code-block:: shell\n\n    python scripts/eval_detectron2.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --d2-config configs/detectron2/coco_segm_default_init_2x.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --weight-init virtex \\\n        --num-gpus-per-machine 8 \\\n        --cpu-workers 2 \\\n        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/coco_segm_500000 \\\n        --checkpoint-every 5000\n\n.. note::\n\n    1. This script periodically serializes checkpoints but skips the validation\n       step during training to save time; to evaluate a serialized checkpoint\n       and write results to tensorboard, provide it as ``--checkpoint-path``\n       with the additional flags ``--resume --eval-only``.\n\n    2. 
Note that ``--d2-config`` here is in Detectron2 format, not in our\n       :class:`~virtex.config.Config` format.\n\n    These points apply to all tasks described below.\n\n-------------------------------------------------------------------------------\n\nInstance Segmentation on LVIS\n-----------------------------\n\nTrain a Mask R-CNN with FPN backbone for LVIS Instance Segmentation by\ninitializing the backbone from VirTex pretrained weights:\n\n.. code-block:: shell\n\n    python scripts/eval_detectron2.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --d2-config configs/detectron2/lvis_segm_default_init_2x.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --weight-init virtex \\\n        --num-gpus-per-machine 8 \\\n        --cpu-workers 2 \\\n        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/lvis_segm_500000 \\\n        --checkpoint-every 5000\n\n-------------------------------------------------------------------------------\n\nObject Detection on PASCAL VOC 2007+12\n--------------------------------------\n\nTrain a Faster R-CNN with C4 backbone for PASCAL VOC 2007+12 Object Detection\nby initializing the backbone from VirTex pretrained weights:\n\n.. 
code-block:: shell\n\n    python scripts/eval_detectron2.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --d2-config configs/detectron2/voc_det_default_init_24k.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --weight-init virtex \\\n        --num-gpus-per-machine 8 \\\n        --cpu-workers 2 \\\n        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/voc_det_500000 \\\n        --checkpoint-every 2500\n\n-------------------------------------------------------------------------------\n\niNaturalist 2018 Fine-Grained Classification\n--------------------------------------------\n\nFine-tune the VirTex pretrained visual backbone end-to-end on the iNaturalist\n2018 dataset:\n\n.. code-block:: shell\n\n    python scripts/clf_linear.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --down-config configs/downstream/inaturalist_clf.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --weight-init virtex \\\n        --num-gpus-per-machine 8 \\\n        --cpu-workers 4 \\\n        --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/inaturalist_500000 \\\n        --checkpoint-every 1710  # 1 epoch of iNaturalist\n\n-------------------------------------------------------------------------------\n\nImage Captioning on COCO Captions val2017\n-----------------------------------------\n\nEvaluate a pretrained VirTex model on image captioning for the COCO Captions\nval2017 split (reporting CIDEr and SPICE metrics):\n\n.. 
code-block:: shell\n\n    python scripts/eval_captioning.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --calc-metrics \\\n        --num-gpus-per-machine 1 \\\n        --cpu-workers 4\n\n-------------------------------------------------------------------------------\n\nRunning Image Captioning Inference on Arbitrary Images\n------------------------------------------------------\n\nThe above script can also generate captions for any images in a directory.\nReplace certain arguments as follows:\n\n.. code-block:: shell\n\n    python scripts/eval_captioning.py \\\n        --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \\\n        --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \\\n        --data-root /path/to/images_dir \\\n        --output /path/to/save/predictions.json \\\n        --num-gpus-per-machine 1 \\\n        --cpu-workers 4\n\nThis script saves predictions in JSON format. Since improving image captioning\nis not our goal, these models may not generate the best captions.\n"
  },
  {
    "path": "docs/virtex/usage/model_zoo.rst",
    "content": "VirTex Model Zoo\n================\n\nWe provide a collection of pretrained model weights and corresponding config\nnames in this model zoo. Tables contain partial paths to config files for each\nmodel, download link for pretrained weights and for reference -- VOC07 mAP and\nImageNet top-1 accuracy.\n\nThe simplest way to download and use a *full* pretrained model (including both,\nthe visual backbone and the textual head) is through :doc:`../model_zoo` API as\nfollows. This code snippet works from anywhere, and does not require to be\nexecuted from project root.\n\n.. code-block:: python\n\n    # Get our full best performing VirTex model:\n    import virtex.model_zoo as mz\n    model = mz.get(\"width_ablations/bicaptioning_R_50_L1_H2048.yaml\", pretrained=True)\n\n    # Optionally extract the torchvision-like visual backbone (with ``avgpool``\n    # and ``fc`` layers replaced with ``nn.Identity`` module).\n    cnn = model.visual.cnn\n\nAlternatively, weights can be manually downloaded from links below, and this\ncan be executed from the project root:\n\n.. code-block:: python\n\n    from virtex.config import Config\n    from virtex.factories import PretrainingModelFactory\n    from virtex.utils.checkpointing import CheckpointManager\n\n    # Get the best performing VirTex model:\n    _C = Config(\"configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml\")\n    model = PretrainingModelFactory.from_config(_C)\n\n    CheckpointManager(model=model).load(\"/path/to/downloaded/weights.pth\")\n\n    # Optionally extract the torchvision-like visual backbone (with ``avgpool``\n    # and ``fc`` layers replaced with ``nn.Identity`` module).\n    cnn = model.visual.cnn\n\n\nThe pretrained ResNet-50 visual backbone of our best performing model\n(``width_ablations/bicaptioning_R_50_L1_H2048.yaml``) can be loaded in a single\nline, *without following any installation steps* (only requires PyTorch v1.5):\n\n.. 
code-block:: python\n\n    import torch\n\n    model = torch.hub.load(\"kdexd/virtex\", \"resnet50\", pretrained=True)\n\n    # This is a torchvision-like resnet50 model, with ``avgpool`` and ``fc``\n    # layers replaced with ``nn.Identity`` module.\n    image_batch = torch.randn(1, 3, 224, 224)  # batch tensor of one image.\n    features_batch = model(image_batch)  # shape: (1, 2048, 7, 7)\n\n-------------------------------------------------------------------------------\n\nPretraining Task Ablations\n^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n.. raw:: html\n\n    <style type=\"text/css\">\n    .tg  {border-collapse:collapse;border-spacing:0;}\n    .tg td{border-color:black;border-style:solid;border-width:1px;\n    overflow:hidden;padding:10px 5px;word-break:normal;}\n    .tg th{border-color:black;border-style:solid;border-width:1px;\n    font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}\n    .tg .tg-zlqz{background-color:#d5d5d5;border-color:inherit;font-weight:bold;text-align:center;vertical-align:middle}\n    .tg .tg-c3ow{border-color:inherit;text-align:center;vertical-align:top}\n    .tg .tg-c3ow a{color: darkgreen; text-decoration: none; border-bottom: 1px dashed green;text-underline-position: under;}\n    .tg .tg-c3ow a:hover{font-weight: 700;border-bottom: 1px solid green;}\n    .tg .tg-0pky{border-color:inherit;text-align:left;vertical-align:top}\n    @media screen and (max-width: 767px) {.tg {width: auto !important;}.tg col {width: auto !important;}.tg-wrap {overflow-x: auto;-webkit-overflow-scrolling: touch;}}</style>\n    <div class=\"tg-wrap\"><table class=\"tg\">\n    <tbody>\n    <tr>\n        <td class=\"tg-zlqz\">Model Config Name</td>\n        <td class=\"tg-zlqz\">VOC07<br>mAP</td>\n        <td class=\"tg-zlqz\">ImageNet<br>Top-1 Acc.</td>\n        <td class=\"tg-zlqz\">Model URL</td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">task_ablations/bicaptioning_R_50_L1_H2048.yaml</td>\n        <td class=\"tg-c3ow\">88.7</td>\n        
<td class=\"tg-c3ow\">53.8</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/mbeeso8wyieq8wy/bicaptioning_R_50_L1_H2048.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">task_ablations/captioning_R_50_L1_H2048.yaml</td>\n        <td class=\"tg-c3ow\">88.6</td>\n        <td class=\"tg-c3ow\">50.8</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/r6zen9k43m5oo58/captioning_R_50_L1_H2048.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">task_ablations/token_classification_R_50.yaml</td>\n        <td class=\"tg-c3ow\">88.8</td>\n        <td class=\"tg-c3ow\">48.6</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/o4p9lki505r0mef/token_classification_R_50.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">task_ablations/multilabel_classification_R_50.yaml</td>\n        <td class=\"tg-c3ow\">86.2</td>\n        <td class=\"tg-c3ow\">46.2</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/hbspp3jv3u8h3bc/multilabel_classification_R_50.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">task_ablations/masked_lm_R_50_L1_H2048.yaml</td>\n        <td class=\"tg-c3ow\">86.4</td>\n        <td class=\"tg-c3ow\">46.7</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/ldzrk6vem4mg6bl/masked_lm_R_50_L1_H2048.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    </tbody>\n    </table></div>\n\n\nWidth Ablations\n^^^^^^^^^^^^^^^\n\n.. 
raw:: html\n\n    <div class=\"tg-wrap\"><table class=\"tg\">\n    <tbody>\n    <tr>\n        <td class=\"tg-zlqz\">Model Config Name</td>\n        <td class=\"tg-zlqz\">VOC07<br>mAP</td>\n        <td class=\"tg-zlqz\">ImageNet<br>Top-1 Acc.</td>\n        <td class=\"tg-zlqz\">Model URL</td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">width_ablations/bicaptioning_R_50_L1_H512.yaml</td>\n        <td class=\"tg-c3ow\">88.4</td>\n        <td class=\"tg-c3ow\">51.8</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/o9fr69jjqfn8a65/bicaptioning_R_50_L1_H512.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\"><span style=\"font-weight:400;font-style:normal\">width_ablations/bicaptioning_R_50_L1_H768.yaml</span></td>\n        <td class=\"tg-c3ow\">88.3</td>\n        <td class=\"tg-c3ow\">52.3</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/1zxglqrrbfufv9d/bicaptioning_R_50_L1_H768.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\"><span style=\"font-weight:400;font-style:normal\">width_ablations/bicaptioning_R_50_L1_H1024.yaml</span></td>\n        <td class=\"tg-c3ow\">88.3</td>\n        <td class=\"tg-c3ow\">53.2</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/pdat4tvhnqxel64/bicaptioning_R_50_L1_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\"><span style=\"font-weight:400;font-style:normal\">width_ablations/bicaptioning_R_50_L1_H2048.yaml</span></td>\n        <td class=\"tg-c3ow\">88.7</td>\n        <td class=\"tg-c3ow\">53.8</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/mbeeso8wyieq8wy/bicaptioning_R_50_L1_H2048.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    </tbody>\n    </table></div>\n\n\nDepth 
Ablations\n^^^^^^^^^^^^^^^\n\n.. raw:: html\n\n    <div class=\"tg-wrap\"><table class=\"tg\">\n    <tbody>\n    <tr>\n        <td class=\"tg-zlqz\">Model Config Name</td>\n        <td class=\"tg-zlqz\">VOC07<br>mAP</td>\n        <td class=\"tg-zlqz\">ImageNet<br>Top-1 Acc.</td>\n        <td class=\"tg-zlqz\">Model URL</td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">depth_ablations/bicaptioning_R_50_L1_H1024.yaml</td>\n        <td class=\"tg-c3ow\">88.3</td>\n        <td class=\"tg-c3ow\">53.2</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/pdat4tvhnqxel64/bicaptioning_R_50_L1_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">depth_ablations/bicaptioning_R_50_L2_H1024.yaml</td>\n        <td class=\"tg-c3ow\">88.8</td>\n        <td class=\"tg-c3ow\">53.8</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/ft1vtt4okirzjgo/bicaptioning_R_50_L2_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\"><span style=\"font-weight:400;font-style:normal\">depth_ablations/bicaptioning_R_50_L3_H1024.yaml</span></td>\n        <td class=\"tg-c3ow\">88.7</td>\n        <td class=\"tg-c3ow\">53.9</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/5ldo1rcsnrshmjr/bicaptioning_R_50_L3_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\"><span style=\"font-weight:400;font-style:normal\">depth_ablations/bicaptioning_R_50_L4_H1024.yaml</span></td>\n        <td class=\"tg-c3ow\">88.7</td>\n        <td class=\"tg-c3ow\">53.9</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/zgiit2wcluuq3xh/bicaptioning_R_50_L4_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    </tbody>\n    </table></div>\n\n\nBackbone 
Ablations\n^^^^^^^^^^^^^^^^^^\n\n.. raw:: html\n\n    <div class=\"tg-wrap\"><table class=\"tg\">\n    <tbody>\n    <tr>\n        <td class=\"tg-zlqz\">Model Config Name</td>\n        <td class=\"tg-zlqz\">VOC07<br>mAP</td>\n        <td class=\"tg-zlqz\">ImageNet<br>Top-1 Acc.</td>\n        <td class=\"tg-zlqz\">Model URL</td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">backbone_ablations/bicaptioning_R_50_L1_H1024.yaml</td>\n        <td class=\"tg-c3ow\">88.3</td>\n        <td class=\"tg-c3ow\">53.2</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/pdat4tvhnqxel64/bicaptioning_R_50_L1_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml</td>\n        <td class=\"tg-c3ow\">88.5</td>\n        <td class=\"tg-c3ow\">52.9</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/5o198ux709r6376/bicaptioning_R_50W2X_L1_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    <tr>\n        <td class=\"tg-0pky\">backbone_ablations/bicaptioning_R_101_L1_H1024.yaml</td>\n        <td class=\"tg-c3ow\">88.7</td>\n        <td class=\"tg-c3ow\">52.1</td>\n        <td class=\"tg-c3ow\"><a href=\"https://www.dropbox.com/s/bb74jubt68cpn80/bicaptioning_R_101_L1_H1024.pth?dl=0\" target=\"_blank\" rel=\"noopener noreferrer\">model</a></td>\n    </tr>\n    </tbody>\n    </table></div>\n"
  },
  {
    "path": "docs/virtex/usage/pretrain.rst",
    "content": "How to train your VirTex model?\n===============================\n\nWe provide training scripts for all type of VirTex models from the paper;\nincluding our best-performing model and other ablations.\nOur training jobs are specified by config files (YAML).\nExecute all commands from project root to use the provided config files.\n\n\nTraining the base VirTex model\n------------------------------\n\nTrain the base VirTex model with ResNet-50 visual backbone; and a textual head\nwith ``L = 1, H = 1024`` using all default optimization hyperparameters.\n\n.. code-block::\n\n    python scripts/pretrain_virtex.py \\\n        --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \\\n        --num-gpus-per-machine 8 \\\n        --cpu-workers 4 \\\n        --serialization-dir /tmp/VIRTEX_R_50_L1_H1024\n        # Default: --checkpoint-every 2000 --log-every 20\n\nTraining job will save checkpoints, tensorboard logs (loss curves and metrics),\nand back up the config in ``--serialization-dir``. Use ``tensorboard --logdir\n<serialization_dir>`` to view training curves, validation metrics etc. directly\non tensorboard.\n\nWe recommend training with 8 GPUs on the same machine, although training with\nmultiple GPUs across machines (see: ``--num-machines`` and ``--machine-rank``),\nsingle GPU (``--num-gpus-per-machine 1``) as well as CPU\n(``--num-gpus-per-machine 0``) is also supported. Using multiple GPUs for\ninteractive debugging with PDB is not supported, as PDB and ``multiprocessing``\nmodule do not play nice.\n\n-------------------------------------------------------------------------------\n\nReproducing all VirTex ablations\n--------------------------------\n\nTo reproduce all ablations from the `paper <https://arxiv.org/abs/2006.06666>`_,\nreplace the ``--config`` argument in above command with the following (all\nassumed to be relative to project root):\n\nPretraining Task Ablations\n^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. 
**Bicaptioning:** configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml\n2. **Forward Captioning:** configs/task_ablations/captioning_R_50_L1_H2048.yaml\n3. **Token Classification:** configs/task_ablations/token_classification_R_50.yaml\n4. **Multilabel Classification:** configs/task_ablations/multilabel_classification_R_50.yaml\n5. **Masked Language Modeling:** configs/task_ablations/masked_lm_R_50_L1_H2048.yaml\n\nTransformer Size Ablations\n^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n1. **Width (H = 512):** configs/width_ablations/bicaptioning_R_50_L1_H512.yaml\n2. **Width (H = 768):** configs/width_ablations/bicaptioning_R_50_L1_H768.yaml\n3. **Width (H = 1024):** configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml\n4. **Width (H = 2048):** configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml\n5. **Depth (L = 1):** configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml\n6. **Depth (L = 2):** configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml\n7. **Depth (L = 3):** configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml\n8. **Depth (L = 4):** configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml\n\nBackbone Ablations\n^^^^^^^^^^^^^^^^^^\n\n1. **ResNet-50:** configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml\n2. **ResNet-50 w2x:** configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml\n3. **ResNet-101:** configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml\n\n.. note::\n\n    **Pretraining Task Ablations** (1), **Transformer Size Ablations** (3 and 5)\n    and **Backbone Ablations** (1) are all the same exact model.\n"
  },
  {
    "path": "docs/virtex/usage/setup_dependencies.rst",
    "content": "How to setup this codebase?\n===========================\n\n.. raw:: html\n\n    <hr>\n\nThis codebase requires Python 3.6+ or higher. We recommend using Anaconda or\nMiniconda. We walk through installation and data preprocessing here.\n\n\nInstall Dependencies\n--------------------\n\nFor these steps to install through Anaconda (or Miniconda).\n\n1. Install Anaconda or Miniconda distribution based on Python 3+ from their\n   `downloads site <https://conda.io/docs/user-guide/install/download.html>`_.\n\n\n2. Clone the repository first.\n\n    .. code-block:: shell\n\n        git clone https://www.github.com/kdexd/virtex\n\n\n3. Create a conda environment and install all the dependencies.\n\n    .. code-block:: shell\n\n        cd virtex\n        conda create -n virtex python=3.8\n        conda activate virtex\n        pip install -r requirements.txt\n\n\n4. Install additional packages from Github.\n\n    .. code-block:: shell\n\n        pip install git+git://github.com/facebookresearch/fvcore.git#egg=fvcore\n        pip install git+git://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI\n\n\n5. Install this codebase as a package in development version.\n\n    .. code-block:: shell\n\n        python setup.py develop\n\nNow you can ``import virtex`` from anywhere as long as you have this conda\nenvironment activated.\n\n-------------------------------------------------------------------------------\n\n\nSetup Datasets\n--------------\n\nDatasets are assumed to exist in ``./datasets`` directory (relative to the\nproject root) following the structure specified below. COCO is used for\npretraining, and rest of the datasets (including COCO) are used for downstream\ntasks. This structure is compatible when using\n`Detectron2 <https://github.com/facebookresearch/detectron2>`_ for downstream\ntasks.\n\nCOCO\n^^^^\n.. 
code-block::\n\n    datasets/coco/\n        annotations/\n            captions_{train,val}2017.json\n            instances_{train,val}2017.json\n        train2017/\n            # images in train2017 split\n        val2017/\n            # images in val2017 split\n\nLVIS\n^^^^\n.. code-block::\n\n    datasets/coco/\n        train2017/\n        val2017/\n    datasets/lvis/\n        lvis_v1.0_{train,val}.json\n\nPASCAL VOC\n^^^^^^^^^^\n.. code-block::\n\n    datasets/VOC2007/\n        Annotations/\n        ImageSets/\n            Main/\n                trainval.txt\n                test.txt\n        JPEGImages/\n\n    datasets/VOC2012/\n        # Same as VOC2007 above\n\nImageNet\n^^^^^^^^\n.. code-block::\n\n    datasets/imagenet/\n        train/\n            # One directory per category with images in it\n        val/\n            # One directory per category with images in it\n        ILSVRC2012_devkit_t12.tar.gz\n\niNaturalist 2018\n^^^^^^^^^^^^^^^^\n.. code-block::\n\n    datasets/inaturalist/\n        train_val2018/\n        annotations/\n            train2018.json\n            val2018.json\n\n-------------------------------------------------------------------------------\n\n\nBuild vocabulary\n----------------\n\nBuild a vocabulary out of COCO Captions ``train2017`` split.\n\n    .. code-block:: shell\n\n        python scripts/build_vocabulary.py \\\n            --captions datasets/coco/annotations/captions_train2017.json \\\n            --vocab-size 10000 \\\n            --output-prefix datasets/vocab/coco_10k \\\n            --do-lower-case\n\nThat's it! You are all set to use this codebase.\n"
  },
  {
    "path": "docs/virtex/utils.beam_search.rst",
    "content": "virtex.utils.beam_search\n========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.beam_search\n"
  },
  {
    "path": "docs/virtex/utils.checkpointing.rst",
    "content": "virtex.utils.checkpointing\n==========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.checkpointing\n"
  },
  {
    "path": "docs/virtex/utils.common.rst",
    "content": "virtex.utils.common\n===================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.common\n"
  },
  {
    "path": "docs/virtex/utils.distributed.rst",
    "content": "virtex.utils.distributed\n========================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.distributed\n"
  },
  {
    "path": "docs/virtex/utils.metrics.rst",
    "content": "virtex.utils.metrics\n====================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.metrics\n"
  },
  {
    "path": "docs/virtex/utils.rst",
    "content": "virtex.utils\n============\n\n.. raw:: html\n\n    <hr>\n\n.. toctree::\n\n    utils.common\n    utils.distributed\n    utils.timer\n    utils.checkpointing\n    utils.beam_search\n    utils.metrics\n"
  },
  {
    "path": "docs/virtex/utils.timer.rst",
    "content": "virtex.utils.timer\n==================\n\n.. raw:: html\n\n    <hr>\n\n.. automodule:: virtex.utils.timer\n"
  },
  {
    "path": "hubconf.py",
    "content": "dependencies = [\"torch\"]\n\nimport torch\nimport torchvision\n\n\nR50_URL = \"https://www.dropbox.com/s/pxgjxcva7oypf12/backbone_bicaptioning_R_50_L1_H2048.pth?dl=1\"\n\n\ndef resnet50(pretrained: bool = False, **kwargs):\n    r\"\"\"\n    ResNet-50 visual backbone from the best performing VirTex model: pretrained\n    for bicaptioning on COCO Captions, with textual head ``L = 1, H = 2048``.\n\n    This is a torchvision-like model, with the last ``avgpool`` and `fc``\n    modules replaced with ``nn.Identity()`` modules. Given a batch of image\n    tensors with size ``(B, 3, 224, 224)``, this model computes spatial image\n    features of size ``(B, 7, 7, 2048)``, where B = batch size.\n\n    pretrained (bool): Whether to load model with pretrained weights.\n    \"\"\"\n\n    # Create a torchvision resnet50 with randomly initialized weights.\n    model = torchvision.models.resnet50(pretrained=False, **kwargs)\n\n    # Replace global average pooling and fully connected layers with identity\n    # modules.\n    model.avgpool = torch.nn.Identity()\n    model.fc = torch.nn.Identity()\n\n    if pretrained:\n        model.load_state_dict(\n            torch.hub.load_state_dict_from_url(R50_URL, progress=False)\n        )\n    return model\n"
  },
  {
    "path": "requirements.txt",
    "content": "albumentations>=1.0\nCython>=0.25\nfuture==0.18.0\nloguru>=0.3\nlvis>=0.5\nnumpy>=1.17\nopencv-python>=4.2.0\nscikit-learn>=1.0\nsentencepiece>=0.1.90\ntorch>=1.9\ntorchvision>=0.10\ntqdm>=4.50.0\n"
  },
  {
    "path": "scripts/build_vocabulary.py",
    "content": "import argparse\nimport json\nimport os\nimport tempfile\nimport unicodedata\nfrom typing import List\n\nimport sentencepiece as sp\n\n\n# fmt: off\nparser = argparse.ArgumentParser(\n    description=\"\"\"Build a vocabulary out of captions corpus. This vocabulary\n    would be a file which our tokenizer can understand.\n    \"\"\"\n)\nparser.add_argument(\n    \"-c\", \"--captions\", default=\"datasets/coco/annotations/captions_train2017.json\",\n    help=\"Path to caption annotations file in COCO format.\",\n)\nparser.add_argument(\n    \"-s\", \"--vocab-size\", type=int, default=10000,\n    help=\"Total desired size of our vocabulary.\",\n)\nparser.add_argument(\n    \"-o\", \"--output-prefix\", default=\"datasets/vocab/coco_10k\",\n    help=\"Prefix of the files to be saved. Two files will be saved: \"\n    \"[prefix].model and [prefix].vocab\",\n)\nparser.add_argument(\n    \"-l\", \"--do-lower-case\", action=\"store_true\",\n    help=\"Whether to lower case the captions before forming vocabulary.\",\n)\nparser.add_argument(\n    \"-a\", \"--keep-accents\", action=\"store_true\",\n    help=\"Whether to keep accents before forming vocabulary (dropped by default).\",\n)\n# fmt: on\n\n\ndef _read_captions(annotations_path: str) -> List[str]:\n    r\"\"\"\n    Given a path to annotation file, read it and return a list of captions.\n    These are not processed by any means, returned from the file as-is.\n\n    Args:\n        annotations_path: Path to an annotations file containing captions.\n\n    Returns:\n        List of captions from this annotation file.\n    \"\"\"\n\n    _annotations = json.load(open(annotations_path))\n\n    captions: List[str] = []\n    for ann in _annotations[\"annotations\"]:\n        captions.append(ann[\"caption\"])\n\n    return captions\n\n\nif __name__ == \"__main__\":\n    _A = parser.parse_args()\n    captions: List[str] = _read_captions(_A.captions)\n\n    # Lower case the captions and remove accents according to 
arguments.\n    for i, caption in enumerate(captions):\n        caption = caption.lower() if _A.do_lower_case else caption\n\n        if not _A.keep_accents:\n            # NFKD-normalize, then drop all combining characters (accents).\n            caption = unicodedata.normalize(\"NFKD\", caption)\n            caption = \"\".join(\n                [char for char in caption if not unicodedata.combining(char)]\n            )\n\n        captions[i] = caption\n\n    # Create a temporary directory and dump the captions corpus as a text file\n    # with one caption per line. That's how sentencepiece wants its input.\n    tmpdir_path = tempfile.mkdtemp()\n\n    with open(os.path.join(tmpdir_path, \"captions.txt\"), \"w\") as captions_file:\n        for caption in captions:\n            captions_file.write(caption + \"\\n\")\n\n    # Padding/out-of-vocab token will be \"<unk>\" with ID 0 by default.\n    # Add [SOS], [EOS] and [MASK] tokens. [MASK] will not be used during\n    # captioning, but it is good to have so the vocabulary can be reused\n    # across pretext tasks.\n    sp.SentencePieceTrainer.train(\n        f\" --input={os.path.join(tmpdir_path, 'captions.txt')}\"\n        f\" --vocab_size={_A.vocab_size}\"\n        f\" --model_prefix={_A.output_prefix}\"\n        \" --model_type=bpe --character_coverage=1.0\"\n        \" --bos_id=-1 --eos_id=-1\"\n        \" --control_symbols=[SOS],[EOS],[MASK]\"\n    )\n"
  },
  {
    "path": "scripts/clf_linear.py",
    "content": "import argparse\nimport os\n\nfrom loguru import logger\nimport torch\nfrom torch import nn\nfrom torch.cuda import amp\nfrom torch.utils.data import DataLoader, DistributedSampler\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom virtex.config import Config\nfrom virtex.factories import (\n    DownstreamDatasetFactory,\n    PretrainingModelFactory,\n    OptimizerFactory,\n    LRSchedulerFactory,\n)\nfrom virtex.utils.checkpointing import CheckpointManager\nfrom virtex.utils.common import common_parser, common_setup, cycle\nimport virtex.utils.distributed as dist\nfrom virtex.utils.metrics import TopkAccuracy\nfrom virtex.utils.timer import Timer\n\n\n# fmt: off\nparser = common_parser(\n    description=\"\"\"Do image classification with linear models and frozen\n    feature extractor, or fine-tune the feature extractor end-to-end.\"\"\"\n)\ngroup = parser.add_argument_group(\"Downstream config arguments.\")\ngroup.add_argument(\n    \"--down-config\", metavar=\"FILE\", help=\"Path to a downstream config file.\"\n)\ngroup.add_argument(\n    \"--down-config-override\", nargs=\"*\", default=[],\n    help=\"A list of key-value pairs to modify downstream config params.\",\n)\n\nparser.add_argument_group(\"Checkpointing and Logging\")\nparser.add_argument(\n    \"--weight-init\", choices=[\"random\", \"imagenet\", \"torchvision\", \"virtex\"],\n    default=\"virtex\", help=\"\"\"How to initialize weights:\n        1. 'random' initializes all weights randomly\n        2. 'imagenet' initializes backbone weights from torchvision model zoo\n        3. 
{'torchvision', 'virtex'} load state dict from --checkpoint-path\n            - with 'torchvision', state dict would be from PyTorch's training\n              script.\n            - with 'virtex' it should be for our full pretrained model.\"\"\"\n)\nparser.add_argument(\n    \"--log-every\", type=int, default=50,\n    help=\"\"\"Log training curves to tensorboard after every these many iterations.\n    Only the master process logs averaged loss values across processes.\"\"\",\n)\nparser.add_argument(\n    \"--checkpoint-path\",\n    help=\"\"\"Path to load checkpoint and run downstream task evaluation. The\n    name of the checkpoint file is required to be `model_*.pth`, where * is the\n    iteration number from which the checkpoint was serialized.\"\"\"\n)\nparser.add_argument(\n    \"--checkpoint-every\", type=int, default=5000,\n    help=\"\"\"Serialize model to a checkpoint after every these many iterations.\n    For ImageNet (5005 iterations = 1 epoch); for iNaturalist (1710 iterations\n    = 1 epoch).\"\"\",\n)\n# fmt: on\n\n\ndef main(_A: argparse.Namespace):\n\n    if _A.num_gpus_per_machine == 0:\n        # Set device as CPU if num_gpus_per_machine = 0.\n        device = torch.device(\"cpu\")\n    else:\n        # Get the current device as set for current distributed process.\n        # Check `launch` function in `virtex.utils.distributed` module.\n        device = torch.cuda.current_device()\n\n    # Create a downstream config object (this will be immutable) and perform\n    # common setup such as logging and setting up serialization directory.\n    _DOWNC = Config(_A.down_config, _A.down_config_override)\n    common_setup(_DOWNC, _A, job_type=\"downstream\")\n\n    # Create a (pretraining) config object and backup in serialization directory.\n    _C = Config(_A.config, _A.config_override)\n    _C.dump(os.path.join(_A.serialization_dir, \"pretrain_config.yaml\"))\n\n    # Get dataset name for tensorboard logging.\n    DATASET = _DOWNC.DATA.ROOT.split(\"/\")[-1]\n\n    
# Set number of output classes according to dataset:\n    NUM_CLASSES_MAPPING = {\"imagenet\": 1000, \"inaturalist\": 8142}\n    NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET]\n\n    # -------------------------------------------------------------------------\n    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER\n    # -------------------------------------------------------------------------\n    train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split=\"train\")\n    train_dataloader = DataLoader(\n        train_dataset,\n        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),\n        num_workers=_A.cpu_workers,\n        sampler=DistributedSampler(\n            train_dataset,\n            num_replicas=dist.get_world_size(),\n            rank=dist.get_rank(),\n            shuffle=True,\n        ),\n        drop_last=False,\n        pin_memory=True,\n        collate_fn=train_dataset.collate_fn,\n    )\n    val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split=\"val\")\n    val_dataloader = DataLoader(\n        val_dataset,\n        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),\n        num_workers=_A.cpu_workers,\n        sampler=DistributedSampler(\n            val_dataset,\n            num_replicas=dist.get_world_size(),\n            rank=dist.get_rank(),\n            shuffle=False,\n        ),\n        pin_memory=True,\n        drop_last=False,\n        collate_fn=val_dataset.collate_fn,\n    )\n    # Initialize model using pretraining config.\n    pretrained_model = PretrainingModelFactory.from_config(_C)\n\n    # Load weights according to the init method, do nothing for `random`, and\n    # `imagenet` is already taken care of.\n    if _A.weight_init == \"virtex\":\n        CheckpointManager(model=pretrained_model).load(_A.checkpoint_path)\n    elif _A.weight_init == \"torchvision\":\n        # Keep strict=False because this state dict may have weights for\n        # last fc layer.\n        
pretrained_model.visual.cnn.load_state_dict(\n            torch.load(_A.checkpoint_path, map_location=\"cpu\")[\"state_dict\"],\n            strict=False,\n        )\n\n    # Pull out the CNN (torchvision-like) from our pretrained model and add\n    # back the FC layer - this exists in torchvision models, and is set to\n    # `nn.Identity()` during pretraining.\n    model = pretrained_model.visual.cnn  # type: ignore\n    model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device)\n    model = model.to(device)\n\n    # Re-initialize the FC layer.\n    torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01)\n    torch.nn.init.constant_(model.fc.bias.data, 0.0)\n\n    # Freeze all layers except FC as per config param.\n    if _DOWNC.MODEL.VISUAL.FROZEN:\n        # Set model to eval mode to prevent BatchNorm from updating running\n        # mean and std. With only a linear layer, being in eval mode when\n        # training will not matter anyway.\n        model.eval()\n\n        for name, param in model.named_parameters():\n            if \"fc\" not in name:\n                param.requires_grad = False\n\n    # Cross entropy loss and accuracy meter.\n    criterion = nn.CrossEntropyLoss()\n    top1 = TopkAccuracy(k=1)\n\n    optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters())\n    scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer)\n    del pretrained_model\n\n    # -------------------------------------------------------------------------\n    #  BEFORE TRAINING STARTS\n    # -------------------------------------------------------------------------\n\n    # Create a gradient scaler for automatic mixed precision.\n    scaler = amp.GradScaler(enabled=_DOWNC.AMP)\n\n    # Create an iterator from dataloader to sample batches perpetually.\n    train_dataloader_iter = cycle(train_dataloader, device)\n\n    if dist.get_world_size() > 1:\n        dist.synchronize()\n        model = 
nn.parallel.DistributedDataParallel(\n            model, device_ids=[device], find_unused_parameters=True\n        )\n\n    if dist.is_master_process():\n        checkpoint_manager = CheckpointManager(\n            _A.serialization_dir,\n            model=model,\n            optimizer=optimizer,\n            scheduler=scheduler,\n        )\n        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)\n\n    # Keep track of time per iteration and ETA.\n    timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS)\n\n    # -------------------------------------------------------------------------\n    #   TRAINING LOOP\n    # -------------------------------------------------------------------------\n    for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1):\n        timer.tic()\n        optimizer.zero_grad()\n        batch = next(train_dataloader_iter)\n\n        with amp.autocast(enabled=_DOWNC.AMP):\n            logits = model(batch[\"image\"])\n            loss = criterion(logits, batch[\"label\"])\n\n        scaler.scale(loss).backward()\n        scaler.step(optimizer)\n        scaler.update()\n\n        scheduler.step()\n        timer.toc()\n\n        if iteration % _A.log_every == 0 and dist.is_master_process():\n            logger.info(\n                f\"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB\"\n            )\n            tensorboard_writer.add_scalar(f\"{DATASET}/train_loss\", loss, iteration)\n            tensorboard_writer.add_scalar(\n                f\"{DATASET}/learning_rate\",\n                optimizer.param_groups[0][\"lr\"],\n                iteration,\n            )\n\n        # ---------------------------------------------------------------------\n        #   VALIDATION\n        # ---------------------------------------------------------------------\n        if iteration % _A.checkpoint_every == 0:\n            torch.set_grad_enabled(False)\n            model.eval()\n\n            
total_val_loss = torch.tensor(0.0).to(device)\n\n            for val_iteration, batch in enumerate(val_dataloader, start=1):\n                for key in batch:\n                    batch[key] = batch[key].to(device)\n\n                logits = model(batch[\"image\"])\n                loss = criterion(logits, batch[\"label\"])\n                _ = top1(logits, batch[\"label\"])\n                total_val_loss += loss\n\n            # Divide each loss component by number of val batches per GPU.\n            total_val_loss = total_val_loss / val_iteration\n            dist.average_across_processes(total_val_loss)\n\n            # Get accumulated Top-1 accuracy for logging across GPUs.\n            acc = top1.get_result()\n            top1.reset()\n            dist.average_across_processes(acc)\n\n            torch.set_grad_enabled(True)\n\n            # Set model back to train mode only when fine-tuning end-to-end.\n            if not _DOWNC.MODEL.VISUAL.FROZEN:\n                model.train()\n\n            # Save recent checkpoint and best checkpoint based on accuracy.\n            if dist.is_master_process():\n                checkpoint_manager.step(iteration)\n\n                logger.info(f\"Iter: {iteration} | Top-1 accuracy: {acc}\")\n                tensorboard_writer.add_scalar(\n                    f\"{DATASET}/val_loss\", total_val_loss, iteration\n                )\n                # This name scoping will result in Tensorboard displaying all\n                # metrics (VOC07, caption, etc.) 
together.\n                tensorboard_writer.add_scalars(\n                    f\"metrics/{DATASET}\", {\"top1\": acc}, iteration\n                )\n\n        # All processes will wait till master process is done logging.\n        dist.synchronize()\n\n\nif __name__ == \"__main__\":\n    _A = parser.parse_args()\n\n    # Add an arg in config override if `--weight-init` is imagenet.\n    if _A.weight_init == \"imagenet\":\n        _A.config_override.extend([\"MODEL.VISUAL.PRETRAINED\", True])\n\n    if _A.num_gpus_per_machine == 0:\n        main(_A)\n    else:\n        # This will launch `main` and set appropriate CUDA device (GPU ID) as\n        # per process (accessed in the beginning of `main`).\n        dist.launch(\n            main,\n            num_machines=_A.num_machines,\n            num_gpus_per_machine=_A.num_gpus_per_machine,\n            machine_rank=_A.machine_rank,\n            dist_url=_A.dist_url,\n            args=(_A,),\n        )\n"
  },
  {
    "path": "scripts/clf_voc07.py",
    "content": "import argparse\nimport multiprocessing as mp\nimport os\nfrom typing import Any, List\n\nimport numpy as np\nimport torch\nfrom loguru import logger\nfrom sklearn.svm import LinearSVC\nfrom sklearn.metrics import average_precision_score\nfrom sklearn.model_selection import cross_val_score\nfrom torch.nn import functional as F\nfrom torch.utils.data import DataLoader\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm import tqdm\n\nfrom virtex.config import Config\nfrom virtex.factories import PretrainingModelFactory, DownstreamDatasetFactory\nfrom virtex.utils.checkpointing import CheckpointManager\nfrom virtex.utils.common import common_parser, common_setup\n\n\nparser = common_parser(\n    description=\"Train SVMs for VOC2007 classification on a pretrained model.\"\n)\ngroup = parser.add_argument_group(\"Downstream config arguments.\")\ngroup.add_argument(\n    \"--down-config\", metavar=\"FILE\", help=\"Path to a downstream config file.\"\n)\ngroup.add_argument(\n    \"--down-config-override\",\n    nargs=\"*\",\n    default=[],\n    help=\"A list of key-value pairs to modify downstream config params.\",\n)\n\n# fmt: off\nparser.add_argument_group(\"Checkpointing\")\nparser.add_argument(\n    \"--weight-init\", choices=[\"random\", \"imagenet\", \"torchvision\", \"virtex\"],\n    default=\"virtex\", help=\"\"\"How to initialize weights:\n        1. 'random' initializes all weights randomly\n        2. 'imagenet' initializes backbone weights from torchvision model zoo\n        3. 
{'torchvision', 'virtex'} load state dict from --checkpoint-path\n            - with 'torchvision', state dict would be from PyTorch's training\n              script.\n            - with 'virtex' it should be for our full pretrained model.\"\"\"\n)\nparser.add_argument(\n    \"--checkpoint-path\",\n    help=\"Path to load checkpoint and run downstream task evaluation.\"\n)\n# fmt: on\n\n\ndef train_test_single_svm(args):\n\n    feats_train, tgts_train, feats_test, tgts_test, cls_name = args\n    SVM_COSTS = [0.01, 0.1, 1.0, 10.0]\n\n    cls_labels = np.copy(tgts_train)\n    # Meaning of labels in VOC/COCO original loaded target files:\n    # label 0 = not present, set it to -1 as svm train target\n    # label 1 = present. Make the svm train target labels as -1, 1.\n    cls_labels[np.where(cls_labels == 0)] = -1\n\n    # See which cost maximizes the AP for this class.\n    best_crossval_ap: float = 0.0\n    best_crossval_clf = None\n    best_cost: float = 0.0\n\n    # fmt: off\n    for cost in SVM_COSTS:\n        clf = LinearSVC(\n            C=cost, class_weight={1: 2, -1: 1}, penalty=\"l2\",\n            loss=\"squared_hinge\", max_iter=2000,\n        )\n        ap_scores = cross_val_score(\n            clf, feats_train, cls_labels, cv=3, scoring=\"average_precision\",\n        )\n        clf.fit(feats_train, cls_labels)\n\n        # Keep track of best SVM (based on cost) for each class.\n        if ap_scores.mean() > best_crossval_ap:\n            best_crossval_ap = ap_scores.mean()\n            best_crossval_clf = clf\n            best_cost = cost\n\n    logger.info(f\"Best SVM {cls_name}: cost {best_cost}, mAP {best_crossval_ap * 100}\")\n    # fmt: on\n\n    # -------------------------------------------------------------------------\n    #   TEST THE TRAINED SVM (PER CLASS)\n    # -------------------------------------------------------------------------\n    predictions = best_crossval_clf.decision_function(feats_test)\n    evaluate_data_inds = tgts_test != 
-1\n    eval_preds = predictions[evaluate_data_inds]\n\n    cls_labels = np.copy(tgts_test)\n    eval_cls_labels = cls_labels[evaluate_data_inds]\n    eval_cls_labels[np.where(eval_cls_labels == 0)] = -1\n\n    # Binarize class labels to make AP targets.\n    targets = eval_cls_labels > 0\n    return average_precision_score(targets, eval_preds)\n\n\ndef main(_A: argparse.Namespace):\n\n    if _A.num_gpus_per_machine == 0:\n        # Set device as CPU if num_gpus_per_machine = 0.\n        device = torch.device(\"cpu\")\n    else:\n        # Get the current device (this will be zero here by default).\n        device = torch.cuda.current_device()\n\n    # Create a downstream config object (this will be immutable) and perform\n    # common setup such as logging and setting up serialization directory.\n    _DOWNC = Config(_A.down_config, _A.down_config_override)\n    common_setup(_DOWNC, _A, job_type=\"downstream\")\n\n    # Create a (pretraining) config object and backup in serialization directory.\n    _C = Config(_A.config, _A.config_override)\n    _C.dump(os.path.join(_A.serialization_dir, \"pretrain_config.yaml\"))\n\n    # -------------------------------------------------------------------------\n    #   INSTANTIATE DATALOADER, MODEL, AND FEATURE EXTRACTOR\n    # -------------------------------------------------------------------------\n\n    train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split=\"trainval\")\n    train_dataloader = DataLoader(\n        train_dataset,\n        batch_size=_DOWNC.OPTIM.BATCH_SIZE,\n        num_workers=_A.cpu_workers,\n        pin_memory=True,\n    )\n    test_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split=\"test\")\n    test_dataloader = DataLoader(\n        test_dataset,\n        batch_size=_DOWNC.OPTIM.BATCH_SIZE,\n        num_workers=_A.cpu_workers,\n        pin_memory=True,\n    )\n    NUM_CLASSES = len(train_dataset.class_names)\n\n    # Initialize from a checkpoint, but only keep the visual module.\n 
   model = PretrainingModelFactory.from_config(_C)\n\n    # Load weights according to the init method, do nothing for `random`, and\n    # `imagenet` is already taken care of.\n    if _A.weight_init == \"virtex\":\n        ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)\n    elif _A.weight_init == \"torchvision\":\n        # Keep strict=False because this state dict may have weights for\n        # last fc layer.\n        model.visual.cnn.load_state_dict(\n            torch.load(_A.checkpoint_path, map_location=\"cpu\")[\"state_dict\"],\n            strict=False,\n        )\n        # Set ``ITERATION`` to a dummy value.\n        ITERATION = 0\n\n    # Transfer model to GPU and set to eval mode. This is a torchvision model\n    # and it returns features as ``(batch_size, 2048, 7, 7)``.\n    model = model.visual.cnn.to(device).eval()\n\n    # -------------------------------------------------------------------------\n    #   EXTRACT FEATURES FOR TRAINING SVMs\n    # -------------------------------------------------------------------------\n\n    features_train: List[torch.Tensor] = []\n    targets_train: List[torch.Tensor] = []\n\n    features_test: List[torch.Tensor] = []\n    targets_test: List[torch.Tensor] = []\n\n    # VOC07 is small, extract all features and keep them in memory.\n    with torch.no_grad():\n        for batch in tqdm(train_dataloader, desc=\"Extracting train features:\"):\n            features = model(batch[\"image\"].to(device))\n\n            # Global average pool features. 
Assume the tensor is in NCHW format.\n            if len(features.size()) > 2:\n                # shape: (batch_size, visual_feature_size)\n                features = features.mean(dim=(2, 3))\n\n            # L2-normalize the global average pooled features.\n            features = F.normalize(features, dim=-1)\n\n            features_train.append(features.cpu())\n            targets_train.append(batch[\"label\"])\n\n        # Similarly extract test features.\n        for batch in tqdm(test_dataloader, desc=\"Extracting test features:\"):\n            features = model(batch[\"image\"].to(device))\n\n            if len(features.size()) > 2:\n                features = features.mean(dim=(2, 3))\n\n            features = F.normalize(features, dim=-1)\n\n            features_test.append(features.cpu())\n            targets_test.append(batch[\"label\"])\n\n    # Convert batches of features/targets to one large numpy array\n    features_train = torch.cat(features_train, dim=0).numpy()\n    targets_train = torch.cat(targets_train, dim=0).numpy().astype(np.int32)\n\n    features_test = torch.cat(features_test, dim=0).numpy()\n    targets_test = torch.cat(targets_test, dim=0).numpy().astype(np.int32)\n\n    # -------------------------------------------------------------------------\n    #   TRAIN AND TEST SVMs WITH EXTRACTED FEATURES\n    # -------------------------------------------------------------------------\n\n    input_args: List[Any] = []\n\n    # Iterate over all VOC07 classes and train one-vs-all linear SVMs.\n    for cls_idx in range(NUM_CLASSES):\n        # fmt: off\n        input_args.append((\n            features_train, targets_train[:, cls_idx],\n            features_test, targets_test[:, cls_idx],\n            train_dataset.class_names[cls_idx],\n        ))\n        # fmt: on\n\n    pool = mp.Pool(processes=_A.cpu_workers)\n    pool_output = pool.map(train_test_single_svm, input_args)\n\n    # 
-------------------------------------------------------------------------\n    #   TENSORBOARD LOGGING (RELEVANT MAINLY FOR weight_init=checkpoint)\n    # -------------------------------------------------------------------------\n\n    # Tensorboard writer for logging mAP scores. This is useful especially\n    # when weight_init=checkpoint (which may be coming from a training job).\n    tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)\n\n    # Test set mAP, averaged over all classes.\n    test_map = torch.tensor(pool_output).mean()\n    logger.info(f\"Iteration: {ITERATION}, mAP: {test_map * 100}\")\n    tensorboard_writer.add_scalars(\n        \"metrics/voc07_clf\", {\"voc07_mAP\": test_map * 100}, ITERATION\n    )\n\n\nif __name__ == \"__main__\":\n    _A = parser.parse_args()\n\n    if _A.num_gpus_per_machine > 1:\n        raise ValueError(\"Using multiple GPUs is not supported for this script.\")\n\n    # Add an arg in config override if `--weight-init` is imagenet.\n    if _A.weight_init == \"imagenet\":\n        _A.config_override.extend([\"MODEL.VISUAL.PRETRAINED\", True])\n\n    # No distributed training here, just a single process.\n    main(_A)\n"
  },
  {
    "path": "scripts/eval_captioning.py",
    "content": "import argparse\nimport json\nimport os\nfrom typing import Any, Dict, List\n\nfrom loguru import logger\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom virtex.config import Config\nfrom virtex.data import ImageDirectoryDataset\nfrom virtex.factories import TokenizerFactory, PretrainingModelFactory\nfrom virtex.utils.checkpointing import CheckpointManager\nfrom virtex.utils.common import common_parser\nfrom virtex.utils.metrics import CocoCaptionsEvaluator\n\n\n# fmt: off\nparser = common_parser(\n    description=\"\"\"Run image captioning inference on a pretrained model, and/or\n    evaluate pretrained model on COCO Captions val2017 split.\"\"\"\n)\nparser.add_argument(\n    \"--images\", \"--data-root\", default=None,\n    help=\"\"\"Path to a directory containing image files to generate captions for.\n    Default: COCO val2017 image directory as expected relative to project root.\"\"\"\n)\nparser.add_argument(\n    \"--checkpoint-path\", required=True,\n    help=\"Path to load checkpoint and run captioning evaluation.\"\n)\nparser.add_argument(\n    \"--output\", default=None,\n    help=\"Path to save predictions as a JSON file.\"\n)\nparser.add_argument(\n    \"--calc-metrics\", action=\"store_true\",\n    help=\"\"\"Calculate CIDEr and SPICE metrics using ground truth COCO Captions.\n    This flag should not be set when running inference on arbitrary images.\"\"\"\n)\n# fmt: on\n\n\ndef main(_A: argparse.Namespace):\n\n    if _A.num_gpus_per_machine == 0:\n        # Set device as CPU if num_gpus_per_machine = 0.\n        device = torch.device(\"cpu\")\n    else:\n        # Get the current device (this will be zero here by default).\n        device = torch.cuda.current_device()\n\n    _C = Config(_A.config, _A.config_override)\n\n    tokenizer = TokenizerFactory.from_config(_C)\n\n    if _A.data_root is None:\n        _A.data_root = os.path.join(_C.DATA.ROOT, \"val2017\")\n\n    val_dataloader = DataLoader(\n        
ImageDirectoryDataset(_A.data_root),\n        batch_size=_C.OPTIM.BATCH_SIZE,\n        num_workers=_A.cpu_workers,\n        pin_memory=True,\n    )\n    # Initialize model from a checkpoint.\n    model = PretrainingModelFactory.from_config(_C).to(device)\n    ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)\n    model.eval()\n\n    # Make a list of predictions to evaluate.\n    predictions: List[Dict[str, Any]] = []\n\n    for val_iteration, val_batch in enumerate(val_dataloader, start=1):\n\n        val_batch[\"image\"] = val_batch[\"image\"].to(device)\n        with torch.no_grad():\n            output_dict = model(val_batch)\n\n        # Make a dictionary of predictions in COCO format.\n        for image_id, caption in zip(\n            val_batch[\"image_id\"], output_dict[\"predictions\"]\n        ):\n            predictions.append(\n                {\n                    # Convert image id to int if possible (mainly for COCO eval).\n                    \"image_id\": int(image_id) if image_id.isdigit() else image_id,\n                    \"caption\": tokenizer.decode(caption.tolist()),\n                }\n            )\n\n    logger.info(\"Displaying first 25 caption predictions:\")\n    for pred in predictions[:25]:\n        logger.info(f\"{pred['image_id']} :: {pred['caption']}\")\n\n    # Save predictions as a JSON file if specified.\n    if _A.output is not None:\n        # `os.path.dirname` is empty for bare filenames; skip makedirs then.\n        output_dir = os.path.dirname(_A.output)\n        if output_dir:\n            os.makedirs(output_dir, exist_ok=True)\n        with open(_A.output, \"w\") as f:\n            json.dump(predictions, f)\n        logger.info(f\"Saved predictions to {_A.output}\")\n\n    # Calculate CIDEr and SPICE metrics using ground truth COCO Captions. 
This\n    # should be skipped when running inference on arbitrary images.\n    if _A.calc_metrics:\n        # Assume ground truth (COCO val2017 annotations) exist.\n        gt = os.path.join(_C.DATA.ROOT, \"annotations\", \"captions_val2017.json\")\n\n        metrics = CocoCaptionsEvaluator(gt).evaluate(predictions)\n        logger.info(f\"Iter: {ITERATION} | Metrics: {metrics}\")\n\n\nif __name__ == \"__main__\":\n    _A = parser.parse_args()\n    if _A.num_gpus_per_machine > 1:\n        raise ValueError(\"Using multiple GPUs is not supported for this script.\")\n\n    # No distributed training here, just a single process.\n    main(_A)\n"
  },
  {
    "path": "scripts/eval_detectron2.py",
    "content": "\"\"\"\nFinetune a pre-trained model on a downstream task, one of those available in\nDetectron2.\nSupported downstream:\n  - LVIS Instance Segmentation\n  - COCO Instance Segmentation\n  - Pascal VOC 2007+12 Object Detection\n\nReference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py\nThanks to the developers of Detectron2!\n\"\"\"\nimport argparse\nimport os\nimport re\n\nimport torch\nfrom torch.utils.tensorboard import SummaryWriter\n\nimport detectron2 as d2\nfrom detectron2.checkpoint import DetectionCheckpointer\nfrom detectron2.engine import DefaultTrainer, default_setup\nfrom detectron2.evaluation import (\n    LVISEvaluator,\n    PascalVOCDetectionEvaluator,\n    COCOEvaluator,\n)\nfrom detectron2.modeling.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads\n\nfrom virtex.config import Config\nfrom virtex.factories import PretrainingModelFactory\nfrom virtex.utils.checkpointing import CheckpointManager\nfrom virtex.utils.common import common_parser\nimport virtex.utils.distributed as dist\n\n# fmt: off\nparser = common_parser(\n    description=\"Train object detectors from pretrained visual backbone.\"\n)\nparser.add_argument(\n    \"--d2-config\", required=True,\n    help=\"Path to a detectron2 config for downstream task finetuning.\"\n)\nparser.add_argument(\n    \"--d2-config-override\", nargs=\"*\", default=[],\n    help=\"\"\"Key-value pairs from Detectron2 config to override from file.\n    Some keys will be ignored because they are set from other args:\n    [DATALOADER.NUM_WORKERS, SOLVER.EVAL_PERIOD, SOLVER.CHECKPOINT_PERIOD,\n    TEST.EVAL_PERIOD, OUTPUT_DIR]\"\"\",\n)\n\nparser.add_argument_group(\"Checkpointing and Logging\")\nparser.add_argument(\n    \"--weight-init\", choices=[\"random\", \"imagenet\", \"torchvision\", \"virtex\"],\n    default=\"virtex\", help=\"\"\"How to initialize weights:\n        1. 'random' initializes all weights randomly\n        2. 
'imagenet' initializes backbone weights from torchvision model zoo\n        3. {'torchvision', 'virtex'} load state dict from --checkpoint-path\n            - with 'torchvision', state dict would be from PyTorch's training\n              script.\n            - with 'virtex' it should be for our full pretrained model.\"\"\"\n)\nparser.add_argument(\n    \"--checkpoint-path\",\n    help=\"Path to load checkpoint and run downstream task evaluation.\"\n)\nparser.add_argument(\n    \"--resume\", action=\"store_true\", help=\"\"\"Specify this flag when resuming\n    training from a checkpoint saved by Detectron2.\"\"\"\n)\nparser.add_argument(\n    \"--eval-only\", action=\"store_true\",\n    help=\"Skip training and evaluate checkpoint provided at --checkpoint-path.\",\n)\nparser.add_argument(\n    \"--checkpoint-every\", type=int, default=5000,\n    help=\"Serialize the model to a checkpoint every this many iterations.\",\n)\n# fmt: on\n\n\n@ROI_HEADS_REGISTRY.register()\nclass Res5ROIHeadsExtraNorm(Res5ROIHeads):\n    r\"\"\"\n    ROI head with ``res5`` stage followed by a BN layer. 
Used with Faster R-CNN\n    C4/DC5 backbones for VOC detection.\n    \"\"\"\n\n    def _build_res5_block(self, cfg):\n        seq, out_channels = super()._build_res5_block(cfg)\n        norm = d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, out_channels)\n        seq.add_module(\"norm\", norm)\n        return seq, out_channels\n\n\ndef build_detectron2_config(_C: Config, _A: argparse.Namespace):\n    r\"\"\"Build detectron2 config based on our pre-training config and args.\"\"\"\n    _D2C = d2.config.get_cfg()\n\n    # Override some default values based on our config file.\n    _D2C.merge_from_file(_A.d2_config)\n    _D2C.merge_from_list(_A.d2_config_override)\n\n    # Set some config parameters from args.\n    _D2C.DATALOADER.NUM_WORKERS = _A.cpu_workers\n    _D2C.SOLVER.CHECKPOINT_PERIOD = _A.checkpoint_every\n    _D2C.OUTPUT_DIR = _A.serialization_dir\n\n    # Set ResNet depth to override in Detectron2's config.\n    _D2C.MODEL.RESNETS.DEPTH = int(\n        re.search(r\"resnet(\\d+)\", _C.MODEL.VISUAL.NAME).group(1)\n        if \"torchvision\" in _C.MODEL.VISUAL.NAME\n        else re.search(r\"_R_(\\d+)\", _C.MODEL.VISUAL.NAME).group(1)\n        if \"detectron2\" in _C.MODEL.VISUAL.NAME\n        else 0\n    )\n    return _D2C\n\n\nclass DownstreamTrainer(DefaultTrainer):\n    r\"\"\"\n    Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks.\n\n    Arguments:\n        cfg (detectron2.config.CfgNode): Detectron2 config object.\n        weights (Union[str, Dict]): Weights to load in the initialized model.\n            If ``str``, then we assume path to a checkpoint, or if a ``dict``,\n            we assume a state dict. 
This will be an ``str`` only if training\n            is resumed from a Detectron2 checkpoint.\n    \"\"\"\n\n    def __init__(self, cfg, weights):\n\n        super().__init__(cfg)\n\n        # Load pre-trained weights before wrapping to DDP because `ApexDDP` has\n        # some weird issue with `DetectionCheckpointer`.\n        # fmt: off\n        if isinstance(weights, str):\n            # weights are ``str`` means ImageNet init or resume training.\n            self.start_iter = (\n                DetectionCheckpointer(\n                    self._trainer.model,\n                    optimizer=self._trainer.optimizer,\n                    scheduler=self.scheduler\n                ).resume_or_load(weights, resume=True).get(\"iteration\", -1) + 1\n            )\n        elif isinstance(weights, dict):\n            # weights are a state dict means our pretrain init.\n            DetectionCheckpointer(self._trainer.model)._load_model(weights)\n        # fmt: on\n\n    @classmethod\n    def build_evaluator(cls, cfg, dataset_name, output_folder=None):\n        if output_folder is None:\n            output_folder = os.path.join(cfg.OUTPUT_DIR, \"inference\")\n        evaluator_list = []\n        evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type\n        if evaluator_type == \"pascal_voc\":\n            return PascalVOCDetectionEvaluator(dataset_name)\n        elif evaluator_type == \"coco\":\n            return COCOEvaluator(dataset_name, cfg, True, output_folder)\n        elif evaluator_type == \"lvis\":\n            return LVISEvaluator(dataset_name, cfg, True, output_folder)\n\n    def test(self, cfg=None, model=None, evaluators=None):\n        r\"\"\"Evaluate the model and log results to stdout and tensorboard.\"\"\"\n        cfg = cfg or self.cfg\n        model = model or self.model\n\n        tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)\n        results = super().test(cfg, model)\n        flat_results = 
d2.evaluation.testing.flatten_results_dict(results)\n        for k, v in flat_results.items():\n            tensorboard_writer.add_scalar(k, v, self.start_iter)\n\n\ndef main(_A: argparse.Namespace):\n\n    # Local process group is needed for detectron2.\n    pg = list(range(dist.get_world_size()))\n    d2.utils.comm._LOCAL_PROCESS_GROUP = torch.distributed.new_group(pg)\n\n    # Create a config object (this will be immutable) and perform common setup\n    # such as logging and setting up serialization directory.\n    if _A.weight_init == \"imagenet\":\n        _A.config_override.extend([\"MODEL.VISUAL.PRETRAINED\", True])\n    _C = Config(_A.config, _A.config_override)\n\n    # We use `default_setup` from detectron2 to do some common setup, such as\n    # logging, setting up serialization etc. For more info, look into source.\n    _D2C = build_detectron2_config(_C, _A)\n    default_setup(_D2C, _A)\n\n    # Prepare weights to pass in instantiation call of trainer.\n    if _A.weight_init in {\"virtex\", \"torchvision\"}:\n        if _A.resume:\n            # If resuming training, let detectron2 load weights by providing path.\n            model = None\n            weights = _A.checkpoint_path\n        else:\n            # Load backbone weights from VirTex pretrained checkpoint.\n            model = PretrainingModelFactory.from_config(_C)\n            if _A.weight_init == \"virtex\":\n                CheckpointManager(model=model).load(_A.checkpoint_path)\n            else:\n                model.visual.cnn.load_state_dict(\n                    torch.load(_A.checkpoint_path, map_location=\"cpu\")[\"state_dict\"],\n                    strict=False,\n                )\n            weights = model.visual.detectron2_backbone_state_dict()\n    else:\n        # If random or imagenet init, just load weights after initializing model.\n        model = PretrainingModelFactory.from_config(_C)\n        weights = model.visual.detectron2_backbone_state_dict()\n\n    # Back up 
pretrain config and model checkpoint (if provided).\n    _C.dump(os.path.join(_A.serialization_dir, \"pretrain_config.yaml\"))\n    if _A.weight_init == \"virtex\" and not _A.resume:\n        torch.save(\n            model.state_dict(),\n            os.path.join(_A.serialization_dir, \"pretrain_model.pth\"),\n        )\n\n    del model\n    trainer = DownstreamTrainer(_D2C, weights)\n    trainer.test() if _A.eval_only else trainer.train()\n\n\nif __name__ == \"__main__\":\n    _A = parser.parse_args()\n\n    # This will launch `main` and set appropriate CUDA device (GPU ID) as\n    # per process (accessed in the beginning of `main`).\n    dist.launch(\n        main,\n        num_machines=_A.num_machines,\n        num_gpus_per_machine=_A.num_gpus_per_machine,\n        machine_rank=_A.machine_rank,\n        dist_url=_A.dist_url,\n        args=(_A, ),\n    )\n"
  },
  {
    "path": "scripts/pretrain_virtex.py",
    "content": "import argparse\nfrom collections import Counter\nfrom typing import Any\n\nfrom loguru import logger\nimport torch\nfrom torch import nn\nfrom torch.cuda import amp\nfrom torch.utils.data import DataLoader, DistributedSampler\nfrom torch.utils.tensorboard import SummaryWriter\n\n# fmt: off\nfrom virtex.config import Config\nfrom virtex.factories import (\n    PretrainingDatasetFactory, PretrainingModelFactory, OptimizerFactory,\n    LRSchedulerFactory,\n)\nfrom virtex.utils.checkpointing import CheckpointManager\nfrom virtex.utils.common import common_parser, common_setup, cycle\nimport virtex.utils.distributed as dist\nfrom virtex.utils.timer import Timer\n\n\nparser = common_parser(\n    description=\"Train a VirTex model (CNN + Transformer) on COCO Captions.\"\n)\ngroup = parser.add_argument_group(\"Checkpointing and Logging\")\ngroup.add_argument(\n    \"--resume-from\", default=None,\n    help=\"Path to a checkpoint to resume training from (if provided).\"\n)\ngroup.add_argument(\n    \"--checkpoint-every\", type=int, default=2000,\n    help=\"Serialize model to a checkpoint after every these many iterations.\",\n)\ngroup.add_argument(\n    \"--log-every\", type=int, default=20,\n    help=\"\"\"Log training curves to tensorboard after every these many iterations\n    only master process logs averaged loss values across processes.\"\"\",\n)\n# fmt: on\n\n\ndef main(_A: argparse.Namespace):\n\n    if _A.num_gpus_per_machine == 0:\n        # Set device as CPU if num_gpus_per_machine = 0.\n        device: Any = torch.device(\"cpu\")\n    else:\n        # Get the current device as set for current distributed process.\n        # Check `launch` function in `virtex.utils.distributed` module.\n        device = torch.cuda.current_device()\n\n    # Create a config object (this will be immutable) and perform common setup\n    # such as logging and setting up serialization directory.\n    _C = Config(_A.config, _A.config_override)\n    common_setup(_C, 
_A)\n\n    # -------------------------------------------------------------------------\n    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER\n    # -------------------------------------------------------------------------\n    train_dataset = PretrainingDatasetFactory.from_config(_C, split=\"train\")\n    val_dataset = PretrainingDatasetFactory.from_config(_C, split=\"val\")\n\n    # Make `DistributedSampler`s to shard datasets across GPU processes.\n    # Skip this if training on CPUs.\n    train_sampler = (\n        DistributedSampler(train_dataset, shuffle=True)  # type: ignore\n        if _A.num_gpus_per_machine > 0\n        else None\n    )\n    val_sampler = (\n        DistributedSampler(val_dataset, shuffle=False)  # type: ignore\n        if _A.num_gpus_per_machine > 0\n        else None\n    )\n    train_dataloader = DataLoader(\n        train_dataset,\n        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),\n        sampler=train_sampler,\n        shuffle=train_sampler is None,\n        num_workers=_A.cpu_workers,\n        pin_memory=True,\n        drop_last=True,\n        collate_fn=train_dataset.collate_fn,\n    )\n    val_dataloader = DataLoader(\n        val_dataset,\n        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),\n        sampler=val_sampler,\n        shuffle=False,\n        num_workers=_A.cpu_workers,\n        pin_memory=True,\n        drop_last=False,\n        collate_fn=val_dataset.collate_fn,\n    )\n\n    model = PretrainingModelFactory.from_config(_C).to(device)\n    optimizer = OptimizerFactory.from_config(_C, model.named_parameters())\n    scheduler = LRSchedulerFactory.from_config(_C, optimizer)\n\n    # -------------------------------------------------------------------------\n    #   BEFORE TRAINING STARTS\n    # -------------------------------------------------------------------------\n\n    # Create a gradient scaler for automatic mixed precision.\n    scaler = amp.GradScaler(enabled=_C.AMP)\n\n    # Load 
checkpoint to resume training if specified.\n    if _A.resume_from is not None:\n        start_iteration = CheckpointManager(\n            model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler,\n        ).load(_A.resume_from)\n    else:\n        start_iteration = 0\n\n    # Create an iterator from dataloader to sample batches perpetually.\n    train_dataloader_iter = cycle(train_dataloader, device, start_iteration)\n\n    # Wrap model in DDP if using more than one processes.\n    if dist.get_world_size() > 1:\n        dist.synchronize()\n        model = nn.parallel.DistributedDataParallel(model, device_ids=[device])\n\n    # Keep track of time per iteration and ETA.\n    timer = Timer(\n        start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS\n    )\n    # Create tensorboard writer and checkpoint manager (only in master process).\n    if dist.is_master_process():\n        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)\n        tensorboard_writer.add_text(\"config\", f\"```\\n{_C}\\n```\")\n\n        checkpoint_manager = CheckpointManager(\n            _A.serialization_dir,\n            model=model,\n            optimizer=optimizer,\n            scheduler=scheduler,\n            scaler=scaler,\n        )\n\n    # -------------------------------------------------------------------------\n    #   TRAINING LOOP\n    # -------------------------------------------------------------------------\n    for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1):\n        timer.tic()\n        optimizer.zero_grad()\n        batch = next(train_dataloader_iter)\n\n        with amp.autocast(enabled=_C.AMP):\n            output_dict = model(batch)\n            loss = output_dict[\"loss\"]\n\n        scaler.scale(loss).backward()\n\n        # First clip norm of gradients, and then perform optimizer step.\n        scaler.unscale_(optimizer)\n        torch.nn.utils.clip_grad_norm_(model.parameters(), 
_C.OPTIM.CLIP_GRAD_NORM)\n        scaler.step(optimizer)\n\n        scaler.update()\n        scheduler.step()\n        timer.toc()\n\n        # ---------------------------------------------------------------------\n        #   LOGGING\n        # ---------------------------------------------------------------------\n        if iteration % _A.log_every == 0:\n            logger.info(\n                f\"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]\"\n            )\n            if dist.is_master_process():\n                tensorboard_writer.add_scalars(\n                    \"learning_rate\",\n                    {\n                        \"visual\": optimizer.param_groups[0][\"lr\"],\n                        \"common\": optimizer.param_groups[-1][\"lr\"],\n                    },\n                    iteration,\n                )\n                tensorboard_writer.add_scalars(\n                    \"train\", output_dict[\"loss_components\"], iteration\n                )\n\n        # ---------------------------------------------------------------------\n        #   VALIDATION\n        # ---------------------------------------------------------------------\n        if iteration % _A.checkpoint_every == 0:\n            if dist.is_master_process():\n                checkpoint_manager.step(iteration)\n\n            # All processes will wait till master process is done serializing.\n            dist.synchronize()\n\n            torch.set_grad_enabled(False)\n            model.eval()\n\n            # Accumulate different val loss components according to the type of\n            # pretraining model.\n            val_loss_counter: Counter = Counter()\n\n            for val_iteration, val_batch in enumerate(val_dataloader, start=1):\n                for key in val_batch:\n                    val_batch[key] = val_batch[key].to(device)\n                output_dict = model(val_batch)\n\n                
val_loss_counter.update(output_dict[\"loss_components\"])\n\n            # Divide each loss component by number of val batches per GPU.\n            val_loss_dict = {\n                k: v / val_iteration for k, v in dict(val_loss_counter).items()\n            }\n            dist.average_across_processes(val_loss_dict)\n            torch.set_grad_enabled(True)\n            model.train()\n\n            logger.info(f\"Iteration: {iteration} [Val loss: {val_loss_dict}]\")\n            if dist.is_master_process():\n                tensorboard_writer.add_scalars(\"val\", val_loss_dict, iteration)\n\n\nif __name__ == \"__main__\":\n    _A = parser.parse_args()\n\n    if _A.num_gpus_per_machine == 0:\n        main(_A)\n    else:\n        # This will launch `main` and set appropriate CUDA device (GPU ID) as\n        # per process (accessed in the beginning of `main`).\n        dist.launch(\n            main,\n            num_machines=_A.num_machines,\n            num_gpus_per_machine=_A.num_gpus_per_machine,\n            machine_rank=_A.machine_rank,\n            dist_url=_A.dist_url,\n            args=(_A, ),\n        )\n"
  },
  {
    "path": "setup.py",
    "content": "#!/usr/bin/env python\nimport glob\nimport os\nfrom setuptools import setup\nimport shutil\nfrom typing import List\n\n\ndef get_model_zoo_configs() -> List[str]:\n    \"\"\"\n    Return a list of configs to include in package for model zoo. Copy over\n    these configs inside virtex/model_zoo.\n    \"\"\"\n\n    # Use absolute paths while symlinking.\n    source_configs_dir = os.path.join(\n        os.path.dirname(os.path.realpath(__file__)), \"configs\"\n    )\n    destination = os.path.join(\n        os.path.dirname(os.path.realpath(__file__)), \"virtex\", \"model_zoo\", \"configs\"\n    )\n    # Symlink the config directory inside package to have a cleaner pip install.\n\n    # Remove stale symlink/directory from a previous build.\n    if os.path.exists(source_configs_dir):\n        if os.path.islink(destination):\n            os.unlink(destination)\n        elif os.path.isdir(destination):\n            shutil.rmtree(destination)\n\n    if not os.path.exists(destination):\n        try:\n            os.symlink(source_configs_dir, destination)\n        except OSError:\n            # Fall back to copying if symlink fails: ex. on Windows.\n            shutil.copytree(source_configs_dir, destination)\n\n    config_paths = glob.glob(\"configs/**/*.yaml\", recursive=True)\n    return config_paths\n\n\nsetup(\n    name=\"virtex\",\n    version=\"1.4.0\",\n    author=\"Karan Desai and Justin Johnson\",\n    description=\"VirTex: Learning Visual Representations with Textual Annotations\",\n    package_data={\"virtex.model_zoo\": get_model_zoo_configs()},\n    python_requires=\">=3.8\",\n    license=\"MIT\",\n    zip_safe=True,\n)\n"
  },
  {
    "path": "virtex/__init__.py",
    "content": ""
  },
  {
    "path": "virtex/config.py",
    "content": "from typing import Any, List, Optional\n\nfrom fvcore.common.config import CfgNode as CN\n\n\nclass Config:\n    r\"\"\"\n    This class provides package-wide configuration management. It is a\n    nested dict-like structure with nested keys accessible as attributes. It\n    contains sensible default values, which can be modified by (first) a YAML\n    file and (second) a list of attributes and values.\n\n    An instantiated object is immutable: modifying any attribute is illegal.\n    You must override required parameter values either through ``config_file``\n    or ``override_list`` arguments.\n\n    Args:\n        config_file: Path to a YAML file containing config parameters.\n        config_override: A list of sequential attributes and values of parameters.\n            This happens after overriding from YAML file.\n\n    Examples:\n        Let a YAML file named \"config.yaml\" specify these parameters to override::\n\n            OPTIM:\n            BATCH_SIZE: 512\n            LR: 0.01\n\n        >>> _C = Config(\"config.yaml\", [\"OPTIM.BATCH_SIZE\", 1024])\n        >>> _C.LR  # default: 0.001\n        0.01\n        >>> _C.OPTIM.BATCH_SIZE  # default: 256, file: 512\n        1024\n    \"\"\"\n\n    def __init__(\n        self, config_file: Optional[str] = None, override_list: List[Any] = []\n    ):\n        _C = CN()\n\n        # Random seed for NumPy and PyTorch, important for reproducibility.\n        _C.RANDOM_SEED = 0\n        # Train with Automatic Mixed Precision (native PyTorch).\n        _C.AMP = True\n        # Set CUDNN deterministic flag (torch.backends.cudnn.deterministic).\n        # Setting this will ensure exact results on every run at the cost of\n        # little slowdown. Good for debugging.\n        _C.CUDNN_DETERMINISTIC = False\n        # Set CUDNN benchmark flag (torch.backends.cudnn.benchmark). 
Enables\n        # CUDNN to select the fastest implementation for operations based on GPU.\n        # May change results (in decimals) on different hardware, but faster\n        # to train. Turn off while debugging.\n        _C.CUDNN_BENCHMARK = True\n\n        # ---------------------------------------------------------------------\n        #   Data paths and parameters related to dataloading.\n        # ---------------------------------------------------------------------\n        _C.DATA = CN()\n\n        # Path to the dataset root, structured as per README. Path is\n        # assumed to be relative to project root.\n        _C.DATA.ROOT = \"datasets/coco\"\n        # Path to .model file generated by ``sentencepiece``.\n        _C.DATA.TOKENIZER_MODEL = \"datasets/vocab/coco_10k.model\"\n\n        # Handy config params for vocab size and indices of special tokens.\n        # While these can be picked up from the tokenizer, having these in\n        # the config makes it easy to create a model without instantiating too\n        # many tokenizer instances (especially when not needed, e.g. model zoo).\n        # These must match the tokenizer specified by ``TOKENIZER_MODEL`` above.\n        _C.DATA.VOCAB_SIZE = 10000\n        # Index of out-of-vocabulary (and padding) token.\n        _C.DATA.UNK_INDEX = 0\n        # Index of the start-of-sentence [SOS] token.\n        _C.DATA.SOS_INDEX = 1\n        # Index of the end-of-sentence [EOS] token.\n        _C.DATA.EOS_INDEX = 2\n        # Index of the word masking token. 
While not used for captioning, having\n        # this extra token makes it possible to train an MLM model without\n        # re-creating a new vocab mapping.\n        _C.DATA.MASK_INDEX = 3\n\n        # Size of the image (square) to crop from original input image.\n        _C.DATA.IMAGE_CROP_SIZE = 224\n        # Maximum length of input caption (number of tokens).\n        # Longer captions will be truncated to this length.\n        _C.DATA.MAX_CAPTION_LENGTH = 30\n\n        # List of image transforms (pre-processing and data augmentation) to be\n        # applied sequentially (always or randomly) during training and\n        # validation. Refer to ``virtex/factories.py`` for all possible transforms.\n        _C.DATA.IMAGE_TRANSFORM_TRAIN = [\n            \"random_resized_crop\",\n            \"horizontal_flip\",\n            \"color_jitter\",\n            \"normalize\",\n        ]\n        _C.DATA.IMAGE_TRANSFORM_VAL = [\n            \"smallest_resize\",\n            \"center_crop\",\n            \"normalize\",\n        ]\n\n        # Hyper-parameters for masked LM pretraining task. 
These are only used\n        # when ``MODEL.NAME`` is \"masked_lm\".\n        _C.DATA.MASKED_LM = CN()\n        # Fraction of tokens to choose for masking, this must be less than 1.\n        _C.DATA.MASKED_LM.MASK_PROPORTION = 0.15\n        # Probability to replace chosen tokens with [MASK] token.\n        _C.DATA.MASKED_LM.MASK_PROBABILITY = 0.85\n        # Probability to replace chosen tokens with a random token.\n        _C.DATA.MASKED_LM.REPLACE_PROBABILITY = 0.10\n\n        # ---------------------------------------------------------------------\n        #   Model architecture: visual backbone and textual head.\n        # ---------------------------------------------------------------------\n        _C.MODEL = CN()\n\n        # Name of model, based on pretraining task.\n        # Possible choices: {\"token_classification\", \"multilabel_classification\",\n        # \"captioning\", \"bicaptioning\", \"masked_lm\", \"virtex\"}\n        _C.MODEL.NAME = \"virtex\"\n\n        _C.MODEL.VISUAL = CN()\n        # Name of visual backbone. Possible choices: {\"blind\", \"torchvision\"}\n        # Models from torchvision can be specified as shown below.\n        _C.MODEL.VISUAL.NAME = \"torchvision::resnet50\"\n        # Number of channels in pooled spatial features of visual backbone.\n        _C.MODEL.VISUAL.FEATURE_SIZE = 2048\n        # Whether to load ImageNet pretrained weights into visual backbone.\n        _C.MODEL.VISUAL.PRETRAINED = False\n        # Whether to keep visual backbone frozen and train only textual head.\n        _C.MODEL.VISUAL.FROZEN = False\n\n        _C.MODEL.TEXTUAL = CN()\n        # Name of textual head. 
Set to \"none\" for MODEL.NAME = \"*_classification\".\n        # Possible choices: {\"transdec_postnorm\", \"transdec_prenorm\"}.\n        # Architectural hyper-parameters are specified as shown above.\n        _C.MODEL.TEXTUAL.NAME = \"transdec_postnorm::L1_H2048_A32_F8192\"\n        # L = Number of layers in the transformer.\n        # H = Hidden size of the transformer (embeddings, attention features).\n        # A = Number of attention heads in the transformer.\n        # F = Size of feedforward layers in the transformer.\n        # Typically, we have (A = H / 64) and (F = 4 * H).\n\n        # Dropout probability for embedding, hidden features in textual head.\n        _C.MODEL.TEXTUAL.DROPOUT = 0.1\n\n        _C.MODEL.DECODER = CN()\n        # What algorithm to use for decoding. Supported values: {\"beam_search\",\n        # \"nucleus_sampling\"}.\n        _C.MODEL.DECODER.NAME = \"beam_search\"\n        # Number of beams to decode (1 = greedy decoding). Ignored when decoding\n        # through nucleus sampling.\n        _C.MODEL.DECODER.BEAM_SIZE = 5\n        # Size of nucleus for sampling predictions. Ignored when decoding through\n        # beam search.\n        _C.MODEL.DECODER.NUCLEUS_SIZE = 0.9\n        # Maximum length of decoded caption. Decoding may end earlier when [EOS]\n        # token is sampled.\n        _C.MODEL.DECODER.MAX_DECODING_STEPS = _C.DATA.MAX_CAPTION_LENGTH\n\n        # ---------------------------------------------------------------------\n        #   Optimization hyper-parameters, default values are for pretraining\n        #   our best model on bicaptioning task (COCO Captions).\n        # ---------------------------------------------------------------------\n        _C.OPTIM = CN()\n\n        # Name of optimizer to use. Supported values: {\"sgd\", \"adamw\"}.\n        # AdamW uses default (beta1, beta2) values from PyTorch.\n        _C.OPTIM.OPTIMIZER_NAME = \"sgd\"\n        # Momentum co-efficient for SGD. 
Ignored for AdamW.\n        _C.OPTIM.SGD_MOMENTUM = 0.9\n        # Weight decay co-efficient for the optimizer.\n        _C.OPTIM.WEIGHT_DECAY = 0.0001\n        # Regex pattern of params for which there will be no weight decay.\n        _C.OPTIM.NO_DECAY = \".*textual.(embedding|transformer).*(norm.*|bias)\"\n        # Max gradient norm for clipping to avoid exploding gradients.\n        _C.OPTIM.CLIP_GRAD_NORM = 10.0\n\n        # Wrap our optimizer with Lookahead (https://arxiv.org/abs/1907.08610).\n        _C.OPTIM.LOOKAHEAD = CN()\n        _C.OPTIM.LOOKAHEAD.USE = True\n        _C.OPTIM.LOOKAHEAD.ALPHA = 0.5\n        _C.OPTIM.LOOKAHEAD.STEPS = 5\n\n        # We set different learning rates for CNN (visual backbone) and rest of\n        # the model. CNN LR is typically much higher for training from scratch.\n        # Both LRs undergo same warmup-decay schedules.\n\n        # Total batch size (will be distributed evenly across GPUs).\n        _C.OPTIM.BATCH_SIZE = 256\n        # Max learning rate for CNN (visual backbone).\n        _C.OPTIM.CNN_LR = 0.2\n        # Max learning rate for rest of the model.\n        _C.OPTIM.LR = 0.001\n        # Number of iterations to train for, batches are randomly sampled.\n        _C.OPTIM.NUM_ITERATIONS = 500000\n\n        # Number of steps at the start of training for linear LR warmup.\n        _C.OPTIM.WARMUP_STEPS = 10000\n        # Learning rate annealing schedule for decay after warmup.\n        # Possible choices: {\"none\", \"linear\", \"cosine\", \"multistep\"}.\n        _C.OPTIM.LR_DECAY_NAME = \"cosine\"\n        # Steps to decay LR for \"multistep\" schedule.\n        _C.OPTIM.LR_STEPS = []\n        # Factor to multiply with LR for \"multistep\" schedule.\n        _C.OPTIM.LR_GAMMA = 0.1\n\n        # Override parameter values from YAML file first, then from override\n        # list, then add derived params.\n        self._C = _C\n        if config_file is not None:\n            self._C.merge_from_file(config_file)\n 
       self._C.merge_from_list(override_list)\n\n        # Make an instantiated object of this class immutable.\n        self._C.freeze()\n\n    def dump(self, file_path: str):\n        r\"\"\"Save config at the specified file path.\n\n        Args:\n            file_path: Path to save config file (YAML).\n        \"\"\"\n        self._C.dump(stream=open(file_path, \"w\"))\n\n    def __getattr__(self, attr: str):\n        return self._C.__getattr__(attr)\n\n    def __str__(self):\n        return self._C.__str__()\n\n    def __repr__(self):\n        return self._C.__repr__()\n"
  },
  {
    "path": "virtex/data/__init__.py",
    "content": "from .datasets.captioning import CaptioningDataset\nfrom .datasets.classification import (\n    TokenClassificationDataset,\n    MultiLabelClassificationDataset,\n)\nfrom .datasets.masked_lm import MaskedLmDataset\nfrom .datasets.downstream import (\n    ImageNetDataset,\n    INaturalist2018Dataset,\n    VOC07ClassificationDataset,\n    ImageDirectoryDataset,\n)\n\n__all__ = [\n    \"CocoCaptionsDataset\",\n    \"CaptioningDataset\",\n    \"TokenClassificationDataset\",\n    \"MultiLabelClassificationDataset\",\n    \"MaskedLmDataset\",\n    \"ImageDirectoryDataset\",\n    \"ImageNetDataset\",\n    \"INaturalist2018Dataset\",\n    \"VOC07ClassificationDataset\",\n]\n"
  },
  {
    "path": "virtex/data/datasets/captioning.py",
    "content": "import random\nfrom typing import Callable, Dict, List\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom virtex.data.tokenizers import SentencePieceBPETokenizer\nfrom virtex.data import transforms as T\nfrom .coco_captions import CocoCaptionsDataset\n\n\nclass CaptioningDataset(Dataset):\n    r\"\"\"\n    A dataset which provides image-caption (forward and backward) pairs from\n    a COCO Captions annotation file. This is used for pretraining tasks which\n    use captions - bicaptioning, forward captioning and token classification.\n\n    Args:\n        data_root: Path to dataset directory containing images and annotations.\n        split: Name of COCO 2017 split to read. One of ``{\"train\", \"val\"}``.\n        tokenizer: Tokenizer which maps word tokens to their integer IDs.\n        image_transform: List of image transformations, from either\n            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_\n            or :mod:`virtex.data.transforms`.\n        max_caption_length: Maximum number of tokens to keep in caption tokens.\n            Extra tokens will be trimmed from the right end of the token list.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_root: str,\n        split: str,\n        tokenizer: SentencePieceBPETokenizer,\n        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,\n        max_caption_length: int = 30,\n    ):\n        self._dset = CocoCaptionsDataset(data_root, split)\n        self.tokenizer = tokenizer\n        self.image_transform = image_transform\n        self.max_caption_length = max_caption_length\n\n        # Short handles for common tokens for convenience:\n        self.padding_idx = tokenizer.token_to_id(\"<unk>\")\n        self.sos_id = tokenizer.token_to_id(\"[SOS]\")\n        self.eos_id = tokenizer.token_to_id(\"[EOS]\")\n\n    def __len__(self):\n        return len(self._dset)\n\n    def __getitem__(self, idx: int) -> Dict[str, 
torch.Tensor]:\n\n        # keys: {\"image_id\", \"image\", \"captions\"}\n        instance = self._dset[idx]\n        image_id, image, captions = (\n            instance[\"image_id\"],\n            instance[\"image\"],\n            instance[\"captions\"],\n        )\n        caption = random.choice(captions)\n\n        # Transform image-caption pair and convert image from HWC to CHW format.\n        # Pass in caption to image_transform due to paired horizontal flip.\n        # Caption won't be tokenized/processed here.\n        image_caption = self.image_transform(image=image, caption=caption)\n        image, caption = image_caption[\"image\"], image_caption[\"caption\"]\n        image = np.transpose(image, (2, 0, 1))\n\n        caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id]\n        caption_tokens = caption_tokens[: self.max_caption_length]\n        return {\n            \"image_id\": torch.tensor(image_id, dtype=torch.long),\n            \"image\": torch.tensor(image, dtype=torch.float),\n            \"caption_tokens\": torch.tensor(caption_tokens, dtype=torch.long),\n            \"noitpac_tokens\": torch.tensor(caption_tokens, dtype=torch.long).flip(0),\n            \"caption_lengths\": torch.tensor(len(caption_tokens), dtype=torch.long),\n        }\n\n    def collate_fn(\n        self, data: List[Dict[str, torch.Tensor]]\n    ) -> Dict[str, torch.Tensor]:\n\n        # Pad `caption_tokens` and `noitpac_tokens` to the max length in this batch.\n        caption_tokens = torch.nn.utils.rnn.pad_sequence(\n            [d[\"caption_tokens\"] for d in data],\n            batch_first=True,\n            padding_value=self.padding_idx,\n        )\n        noitpac_tokens = torch.nn.utils.rnn.pad_sequence(\n            [d[\"noitpac_tokens\"] for d in data],\n            batch_first=True,\n            padding_value=self.padding_idx,\n        )\n        return {\n            \"image_id\": torch.stack([d[\"image_id\"] for d in data], dim=0),\n            
\"image\": torch.stack([d[\"image\"] for d in data], dim=0),\n            \"caption_tokens\": caption_tokens,\n            \"noitpac_tokens\": noitpac_tokens,\n            \"caption_lengths\": torch.stack([d[\"caption_lengths\"] for d in data]),\n        }\n"
  },
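The token preparation in `CaptioningDataset.__getitem__` above (wrap with `[SOS]`/`[EOS]`, trim from the right, reverse for the backward direction) can be sketched in plain Python. The integer IDs below are hypothetical stand-ins for real tokenizer output, not part of the repository:

```python
# Sketch of CaptioningDataset's caption-token preparation, with made-up IDs.
SOS_ID, EOS_ID = 1, 2        # assumed IDs for [SOS] / [EOS]
MAX_CAPTION_LENGTH = 30      # default used by the dataset classes

def prepare_caption_tokens(encoded):
    # Wrap with boundary tokens, then trim extra tokens from the right end.
    tokens = [SOS_ID, *encoded, EOS_ID][:MAX_CAPTION_LENGTH]
    return {
        "caption_tokens": tokens,
        # "noitpac" is "caption" reversed: the same tokens flipped, mirroring
        # the tensor .flip(0) used for the backward captioning direction.
        "noitpac_tokens": tokens[::-1],
        "caption_lengths": len(tokens),
    }

out = prepare_caption_tokens([10, 11, 12])
print(out["caption_tokens"])   # [1, 10, 11, 12, 2]
print(out["noitpac_tokens"])   # [2, 12, 11, 10, 1]
```

Note that a caption longer than the limit is truncated before reversal, so the backward sequence can begin without an `[EOS]` token.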
  {
    "path": "virtex/data/datasets/classification.py",
    "content": "from collections import defaultdict\nimport glob\nimport json\nimport os\nimport random\nfrom typing import Any, Callable, Dict, List, Tuple\n\nimport albumentations as alb\nimport cv2\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom virtex.data.tokenizers import SentencePieceBPETokenizer\nfrom virtex.data import transforms as T\nfrom .coco_captions import CocoCaptionsDataset\n\n\nclass TokenClassificationDataset(Dataset):\n    r\"\"\"\n    A dataset which provides image-labelset pairs from a COCO Captions annotation\n    file. The set of caption tokens (unordered) is treated as a labelset.\n\n    Args:\n        data_root: Path to dataset directory containing images and annotations.\n        split: Name of COCO 2017 split to read. One of ``{\"train\", \"val\"}``.\n        tokenizer: Tokenizer which maps word tokens to their integer IDs.\n        image_transform: List of image transformations, from either\n            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_\n            or :mod:`virtex.data.transforms`.\n        max_caption_length: Maximum number of tokens to keep in caption tokens.\n            Extra tokens will be trimmed from the right end of the token list.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_root: str,\n        split: str,\n        tokenizer: SentencePieceBPETokenizer,\n        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,\n        max_caption_length: int = 30,\n    ):\n        self._dset = CocoCaptionsDataset(data_root, split)\n        self.image_transform = image_transform\n        self.max_caption_length = max_caption_length\n\n        # Short handles for common tokens for convenience:\n        self.padding_idx = tokenizer.token_to_id(\"<unk>\")\n        self.sos_id = tokenizer.token_to_id(\"[SOS]\")\n        self.eos_id = tokenizer.token_to_id(\"[EOS]\")\n\n    def __len__(self):\n        return len(self._dset)\n\n    def __getitem__(self, idx: 
int) -> Dict[str, torch.Tensor]:\n\n        # keys: {\"image_id\", \"image\", \"captions\"}\n        instance = self._dset[idx]\n        image_id, image, captions = (\n            instance[\"image_id\"],\n            instance[\"image\"],\n            instance[\"captions\"],\n        )\n        caption = random.choice(captions)\n\n        # Transform image-caption pair and convert image from HWC to CHW format.\n        # Pass in caption to image_transform due to paired horizontal flip.\n        # Caption won't be tokenized/processed here.\n        image_caption = self.image_transform(image=image, caption=caption)\n        image, caption = image_caption[\"image\"], image_caption[\"caption\"]\n        image = np.transpose(image, (2, 0, 1))\n\n        caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id]\n        caption_tokens = caption_tokens[: self.max_caption_length]\n        return {\n            \"image_id\": torch.tensor(image_id, dtype=torch.long),\n            \"image\": torch.tensor(image, dtype=torch.float),\n            \"labels\": torch.tensor(caption_tokens, dtype=torch.long),\n        }\n\n    def collate_fn(\n        self, data: List[Dict[str, torch.Tensor]]\n    ) -> Dict[str, torch.Tensor]:\n\n        labels = torch.nn.utils.rnn.pad_sequence(\n            [d[\"labels\"] for d in data],\n            batch_first=True,\n            padding_value=self.padding_idx,\n        )\n        return {\n            \"image_id\": torch.stack([d[\"image_id\"] for d in data], dim=0),\n            \"image\": torch.stack([d[\"image\"] for d in data], dim=0),\n            \"labels\": labels,\n        }\n\n\nclass MultiLabelClassificationDataset(Dataset):\n    r\"\"\"\n    A dataset which provides image-labelset pairs from COCO instance annotation\n    files. 
This is used for multilabel classification pretraining task.\n\n    Args:\n        data_root: Path to dataset directory containing images and annotations.\n        split: Name of COCO 2017 split to read. One of ``{\"train\", \"val\"}``.\n        image_transform: List of image transformations, from either\n            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_\n            or :mod:`virtex.data.transforms`.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_root: str,\n        split: str,\n        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,\n    ):\n        self.image_transform = image_transform\n\n        # Make a tuple of image id and its filename, get image_id from its\n        # filename (assuming directory has images with names in COCO 2017 format).\n        image_filenames = glob.glob(os.path.join(data_root, f\"{split}2017\", \"*.jpg\"))\n        self.id_filename: List[Tuple[int, str]] = [\n            (int(os.path.basename(name)[:-4]), name) for name in image_filenames\n        ]\n        # Load the instance (bounding box and mask) annotations.\n        _annotations = json.load(\n            open(os.path.join(data_root, \"annotations\", f\"instances_{split}2017.json\"))\n        )\n        # Make a mapping between COCO category id and its index, to make IDs\n        # consecutive, else COCO has 80 classes with IDs 1-90. 
Start index from 1\n        # as 0 is reserved for background (padding idx).\n        _category_ids = {\n            ann[\"id\"]: index + 1 for index, ann in enumerate(_annotations[\"categories\"])\n        }\n        # Mapping from image ID to list of unique category IDs (indices as above)\n        # in corresponding image.\n        self._labels: Dict[str, Any] = defaultdict(list)\n\n        for ann in _annotations[\"annotations\"]:\n            self._labels[ann[\"image_id\"]].append(_category_ids[ann[\"category_id\"]])\n\n        # De-duplicate and drop empty labels, we only need to do classification.\n        self._labels = {\n            _id: list(set(lbl)) for _id, lbl in self._labels.items() if len(lbl) > 0\n        }\n        # Filter out image IDs which didn't have any labels.\n        self.id_filename = [\n            (t[0], t[1]) for t in self.id_filename if t[0] in self._labels\n        ]\n        # Padding while forming a batch, because images may have variable number\n        # of instances. 
We do not need padding index from tokenizer: COCO has\n        # category ID 0 as background, conventionally.\n        self.padding_idx = 0\n\n    def __len__(self):\n        return len(self.id_filename)\n\n    def __getitem__(self, idx: int):\n        # Get image ID and filename.\n        image_id, filename = self.id_filename[idx]\n\n        # Open image from path and apply transformation, convert to CHW format.\n        image = cv2.imread(filename)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        image = self.image_transform(image=image)[\"image\"]\n        image = np.transpose(image, (2, 0, 1))\n\n        # Get a list of instances present in the image.\n        labels = self._labels[image_id]\n\n        return {\n            \"image_id\": torch.tensor(image_id, dtype=torch.long),\n            \"image\": torch.tensor(image, dtype=torch.float),\n            \"labels\": torch.tensor(labels, dtype=torch.long),\n        }\n\n    def collate_fn(\n        self, data: List[Dict[str, torch.Tensor]]\n    ) -> Dict[str, torch.Tensor]:\n\n        labels = torch.nn.utils.rnn.pad_sequence(\n            [d[\"labels\"] for d in data],\n            batch_first=True,\n            padding_value=self.padding_idx,\n        )\n        return {\n            \"image_id\": torch.stack([d[\"image_id\"] for d in data], dim=0),\n            \"image\": torch.stack([d[\"image\"] for d in data], dim=0),\n            \"labels\": labels,\n        }\n"
  },
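A small sketch of the category-ID re-indexing that `MultiLabelClassificationDataset` performs above: raw COCO instance categories have non-consecutive IDs (80 classes spread over 1-90), so they are mapped to consecutive indices starting at 1, reserving 0 for background/padding. The toy category list here is hypothetical:

```python
# Toy subset of COCO's "categories" list (real IDs are non-consecutive).
raw_categories = [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 5}, {"id": 90}]

# Map raw category ID -> consecutive index, starting at 1 (0 = background).
category_index = {
    cat["id"]: index + 1 for index, cat in enumerate(raw_categories)
}
print(category_index)  # {1: 1, 2: 2, 3: 3, 5: 4, 90: 5}
```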
  {
    "path": "virtex/data/datasets/coco_captions.py",
    "content": "from collections import defaultdict\nimport json\nimport os\nimport unicodedata\nfrom typing import Dict, List\n\nimport cv2\nfrom torch.utils.data import Dataset\n\n\nclass CocoCaptionsDataset(Dataset):\n    r\"\"\"\n    A PyTorch dataset to read COCO Captions dataset and provide it completely\n    unprocessed. This dataset is used by various task-specific datasets\n    in :mod:`~virtex.data.datasets` module.\n\n    Args:\n        data_root: Path to the COCO dataset root directory.\n        split: Name of COCO 2017 split to read. One of ``{\"train\", \"val\"}``.\n    \"\"\"\n\n    def __init__(self, data_root: str, split: str):\n\n        # Get paths to image directory and annotation file.\n        image_dir = os.path.join(data_root, f\"{split}2017\")\n        captions = json.load(\n            open(os.path.join(data_root, \"annotations\", f\"captions_{split}2017.json\"))\n        )\n        # Collect list of captions for each image.\n        captions_per_image: Dict[int, List[str]] = defaultdict(list)\n\n        for ann in captions[\"annotations\"]:\n            # Perform common normalization (lowercase, trim spaces, NKFC strip\n            # accents and NKFC normalization).\n            caption = ann[\"caption\"].lower()\n            caption = unicodedata.normalize(\"NFKD\", caption)\n            caption = \"\".join([chr for chr in caption if not unicodedata.combining(chr)])\n\n            captions_per_image[ann[\"image_id\"]].append(caption)\n\n        # Collect image file for each image (by its ID).\n        image_filepaths: Dict[int, str] = {\n            im[\"id\"]: os.path.join(image_dir, im[\"file_name\"])\n            for im in captions[\"images\"]\n        }\n        # Keep all annotations in memory. 
Make a list of tuples, each tuple\n        # is ``(image_id, file_path, list[captions])``.\n        self.instances = [\n            (im_id, image_filepaths[im_id], captions_per_image[im_id])\n            for im_id in captions_per_image.keys()\n        ]\n\n    def __len__(self):\n        return len(self.instances)\n\n    def __getitem__(self, idx: int):\n        image_id, image_path, captions = self.instances[idx]\n\n        # shape: (height, width, channels), dtype: uint8\n        image = cv2.imread(image_path)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n\n        return {\"image_id\": image_id, \"image\": image, \"captions\": captions}\n"
  },
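The caption normalization in `CocoCaptionsDataset` above (lowercase, NFKD decomposition, dropping combining marks to strip accents) is self-contained and can be tried in isolation:

```python
import unicodedata

def normalize_caption(caption: str) -> str:
    # Lowercase, NFKD-decompose, then drop combining marks (strips accents).
    caption = unicodedata.normalize("NFKD", caption.lower())
    return "".join(ch for ch in caption if not unicodedata.combining(ch))

print(normalize_caption("A Café near the Señor's house"))
# a cafe near the senor's house
```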
  {
    "path": "virtex/data/datasets/downstream.py",
    "content": "from collections import defaultdict\nimport glob\nimport json\nimport os\nfrom typing import Callable, Dict, List, Tuple\n\nimport cv2\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\nfrom torchvision.datasets import ImageNet\n\nfrom virtex.data import transforms as T\n\n\nclass ImageNetDataset(ImageNet):\n    r\"\"\"\n    Simple wrapper over torchvision's ImageNet dataset. Image transform is\n    handled here instead of passing to super class.\n\n    Args:\n        data_root: Path to the ImageNet dataset directory.\n        split: Which split to read from. One of ``{\"train\", \"val\"}``.\n        image_transform: List of image transformations, from either\n            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_\n            or :mod:`virtex.data.transforms`.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_root: str = \"datasets/imagenet\",\n        split: str = \"train\",\n        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,\n    ):\n        super().__init__(data_root, split)\n        self.image_transform = image_transform\n\n    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:\n        image, label = super().__getitem__(idx)\n\n        # Apply transformation to  image and convert to CHW format.\n        image = self.image_transform(image=np.array(image))[\"image\"]\n        image = np.transpose(image, (2, 0, 1))\n        return {\n            \"image\": torch.tensor(image, dtype=torch.float),\n            \"label\": torch.tensor(label, dtype=torch.long),\n        }\n\n    @staticmethod\n    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:\n        return {\n            \"image\": torch.stack([d[\"image\"] for d in data], dim=0),\n            \"label\": torch.stack([d[\"label\"] for d in data], dim=0),\n        }\n\n\nclass INaturalist2018Dataset(Dataset):\n    r\"\"\"\n    A dataset which provides image-label pairs from the 
iNaturalist 2018 dataset.\n\n    Args:\n        data_root: Path to the iNaturalist 2018 dataset directory.\n        split: Which split to read from. One of ``{\"train\", \"val\"}``.\n        image_transform: List of image transformations, from either\n            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_\n            or :mod:`virtex.data.transforms`.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_root: str = \"datasets/inaturalist\",\n        split: str = \"train\",\n        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,\n    ):\n        self.split = split\n        self.image_transform = image_transform\n\n        annotations = json.load(\n            open(os.path.join(data_root, \"annotations\", f\"{split}2018.json\"))\n        )\n        # Make a list of image IDs to file paths.\n        self.image_id_to_file_path = {\n            ann[\"id\"]: os.path.join(data_root, ann[\"file_name\"])\n            for ann in annotations[\"images\"]\n        }\n        # For a list of instances: (image_id, category_id) tuples.\n        self.instances = [\n            (ann[\"image_id\"], ann[\"category_id\"])\n            for ann in annotations[\"annotations\"]\n        ]\n\n    def __len__(self):\n        return len(self.instances)\n\n    def __getitem__(self, idx: int):\n        image_id, label = self.instances[idx]\n        image_path = self.image_id_to_file_path[image_id]\n\n        # Open image from path and apply transformation, convert to CHW format.\n        image = cv2.imread(image_path)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        image = self.image_transform(image=image)[\"image\"]\n        image = np.transpose(image, (2, 0, 1))\n\n        return {\n            \"image\": torch.tensor(image, dtype=torch.float),\n            \"label\": torch.tensor(label, dtype=torch.long),\n        }\n\n    @staticmethod\n    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:\n        
return {\n            \"image\": torch.stack([d[\"image\"] for d in data], dim=0),\n            \"label\": torch.stack([d[\"label\"] for d in data], dim=0),\n        }\n\n\nclass VOC07ClassificationDataset(Dataset):\n    r\"\"\"\n    A dataset which provides image-label pairs from the PASCAL VOC 2007 dataset.\n\n    Args:\n        data_root: Path to VOC 2007 directory containing sub-directories named\n            ``Annotations``, ``ImageSets``, and ``JPEGImages``.\n        split: Which split to read from. One of ``{\"trainval\", \"test\"}``.\n        image_transform: List of image transformations, from either\n            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_\n            or :mod:`virtex.data.transforms`.\n    \"\"\"\n\n    def __init__(\n        self,\n        data_root: str = \"datasets/VOC2007\",\n        split: str = \"trainval\",\n        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,\n    ):\n        self.split = split\n        self.image_transform = image_transform\n\n        ann_paths = sorted(\n            glob.glob(os.path.join(data_root, \"ImageSets\", \"Main\", f\"*_{split}.txt\"))\n        )\n        # A list like; [\"aeroplane\", \"bicycle\", \"bird\", ...]\n        self.class_names = [\n            os.path.basename(path).split(\"_\")[0] for path in ann_paths\n        ]\n\n        # We will construct a map for image name to a list of\n        # shape: (num_classes, ) and values as one of {-1, 0, 1}.\n        # 1: present, -1: not present, 0: ignore.\n        image_names_to_labels: Dict[str, torch.Tensor] = defaultdict(\n            lambda: -torch.ones(len(self.class_names), dtype=torch.int32)\n        )\n        for cls_num, ann_path in enumerate(ann_paths):\n            with open(ann_path, \"r\") as fopen:\n                for line in fopen:\n                    img_name, orig_label_str = line.strip().split()\n                    orig_label = int(orig_label_str)\n\n                    # In VOC data, -1 (not 
present): set to 0 as train target\n                    # In VOC data, 0 (ignore): set to -1 as train target.\n                    orig_label = (\n                        0 if orig_label == -1 else -1 if orig_label == 0 else 1\n                    )\n                    image_names_to_labels[img_name][cls_num] = orig_label\n\n        # Convert the dict to a list of tuples for easy indexing.\n        # Replace image name with full image path.\n        self.instances: List[Tuple[str, torch.Tensor]] = [\n            (\n                os.path.join(data_root, \"JPEGImages\", f\"{image_name}.jpg\"),\n                label.tolist(),\n            )\n            for image_name, label in image_names_to_labels.items()\n        ]\n\n    def __len__(self):\n        return len(self.instances)\n\n    def __getitem__(self, idx: int):\n        image_path, label = self.instances[idx]\n\n        # Open image from path and apply transformation, convert to CHW format.\n        image = cv2.imread(image_path)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        image = self.image_transform(image=image)[\"image\"]\n        image = np.transpose(image, (2, 0, 1))\n\n        return {\n            \"image\": torch.tensor(image, dtype=torch.float),\n            \"label\": torch.tensor(label, dtype=torch.long),\n        }\n\n    @staticmethod\n    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:\n        return {\n            \"image\": torch.stack([d[\"image\"] for d in data], dim=0),\n            \"label\": torch.stack([d[\"label\"] for d in data], dim=0),\n        }\n\n\nclass ImageDirectoryDataset(Dataset):\n    r\"\"\"\n    A dataset which reads images from any directory. 
This class is useful to\n    run image captioning inference on our models with any arbitrary images.\n\n    Args:\n        data_root: Path to a directory containing images.\n        image_transform: List of image transformations, from either\n            `albumentations <https://albumentations.readthedocs.io/en/latest/>`_\n            or :mod:`virtex.data.transforms`.\n    \"\"\"\n\n    def __init__(\n        self, data_root: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM\n    ):\n        self.image_paths = glob.glob(os.path.join(data_root, \"*\"))\n        self.image_transform = image_transform\n\n    def __len__(self):\n        return len(self.image_paths)\n\n    def __getitem__(self, idx: int):\n        image_path = self.image_paths[idx]\n        # Remove extension from image name to use as image_id.\n        image_id = os.path.splitext(os.path.basename(image_path))[0]\n\n        # Open image from path and apply transformation, convert to CHW format.\n        image = cv2.imread(image_path)\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n        image = self.image_transform(image=image)[\"image\"]\n        image = np.transpose(image, (2, 0, 1))\n\n        # Return image id as string so collate_fn does not cast to torch.tensor.\n        return {\"image_id\": str(image_id), \"image\": torch.tensor(image)}\n"
  },
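The VOC label remapping buried in `VOC07ClassificationDataset.__init__` above can be isolated as a tiny function: the annotation files use 1 = present, -1 = not present, 0 = difficult/ignore, while the training targets use 1 / 0 / -1 respectively. A minimal sketch:

```python
def remap_voc_label(orig_label: int) -> int:
    # VOC file convention:        1 present, -1 not present,  0 ignore.
    # Training-target convention: 1 present,  0 not present, -1 ignore.
    return 0 if orig_label == -1 else -1 if orig_label == 0 else 1

print([remap_voc_label(v) for v in (-1, 0, 1)])  # [0, -1, 1]
```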
  {
    "path": "virtex/data/datasets/masked_lm.py",
    "content": "import math\nimport random\nfrom typing import Callable, Dict, List\n\nimport albumentations as alb\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom virtex.data.tokenizers import SentencePieceBPETokenizer\nfrom virtex.data import transforms as T\nfrom .coco_captions import CocoCaptionsDataset\n\n\nclass MaskedLmDataset(Dataset):\n    def __init__(\n        self,\n        data_root: str,\n        split: str,\n        tokenizer: SentencePieceBPETokenizer,\n        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,\n        max_caption_length: int = 30,\n        mask_proportion: float = 0.15,\n        mask_probability: float = 0.80,\n        replace_probability: float = 0.10,\n    ):\n        self._dset = CocoCaptionsDataset(data_root, split)\n        self.tokenizer = tokenizer\n        self.image_transform = image_transform\n        self.max_caption_length = max_caption_length\n\n        # Short handles for common tokens for convenience:\n        self.padding_idx = tokenizer.token_to_id(\"<unk>\")\n        self.sos_id = tokenizer.token_to_id(\"[SOS]\")\n        self.eos_id = tokenizer.token_to_id(\"[EOS]\")\n        self.mask_id = tokenizer.token_to_id(\"[MASK]\")\n\n        self._vocab_size = tokenizer.get_vocab_size()\n        self._mask_proportion = mask_proportion\n        self._mask_prob = mask_probability\n        self._repl_prob = replace_probability\n\n    def __len__(self):\n        return len(self._dset)\n\n    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:\n\n        # keys: {\"image_id\", \"image\", \"captions\"}\n        instance = self._dset[idx]\n        image_id, image, captions = (\n            instance[\"image_id\"],\n            instance[\"image\"],\n            instance[\"captions\"],\n        )\n        caption = random.choice(captions)\n\n        # Transform image-caption pair and convert image from HWC to CHW format.\n        # Pass in caption to image_transform due to paired 
horizontal flip.\n        # Caption won't be tokenized/processed here.\n        image_caption = self.image_transform(image=image, caption=caption)\n        image, caption = image_caption[\"image\"], image_caption[\"caption\"]\n        image = np.transpose(image, (2, 0, 1))\n\n        caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id]\n        caption_tokens = caption_tokens[: self.max_caption_length]\n        # ---------------------------------------------------------------------\n        #  Mask some tokens randomly.\n        # ---------------------------------------------------------------------\n        masked_labels = [self.padding_idx] * len(caption_tokens)\n\n        # Indices in `caption_tokens` list to mask (minimum 1 index).\n        # Leave out first and last indices (boundary tokens).\n        tokens_to_mask: List[int] = random.sample(\n            list(range(1, len(caption_tokens) - 1)),\n            math.ceil((len(caption_tokens) - 2) * self._mask_proportion),\n        )\n        for i in tokens_to_mask:\n            # Whether to replace with [MASK] or random word.\n            # If only one token, always [MASK].\n            if len(tokens_to_mask) == 1:\n                masked_labels[i] = caption_tokens[i]\n                caption_tokens[i] = self.mask_id\n            else:\n                _flag: float = random.random()\n                if _flag <= self._mask_prob + self._repl_prob:\n                    if _flag <= self._mask_prob:\n                        masked_labels[i] = caption_tokens[i]\n                        caption_tokens[i] = self.mask_id\n                    else:\n                        caption_tokens[i] = self._random_token_index()\n        # ---------------------------------------------------------------------\n\n        return {\n            \"image_id\": torch.tensor(image_id, dtype=torch.long),\n            \"image\": torch.tensor(image, dtype=torch.float),\n            \"caption_tokens\": 
torch.tensor(caption_tokens, dtype=torch.long),\n            \"masked_labels\": torch.tensor(masked_labels, dtype=torch.long),\n            \"caption_lengths\": torch.tensor(len(caption_tokens), dtype=torch.long),\n        }\n\n    def collate_fn(\n        self, data: List[Dict[str, torch.Tensor]]\n    ) -> Dict[str, torch.Tensor]:\n\n        # Pad `caption_tokens` and `masked_labels` up to this length.\n        caption_tokens = torch.nn.utils.rnn.pad_sequence(\n            [d[\"caption_tokens\"] for d in data],\n            batch_first=True,\n            padding_value=self.padding_idx,\n        )\n        masked_labels = torch.nn.utils.rnn.pad_sequence(\n            [d[\"masked_labels\"] for d in data],\n            batch_first=True,\n            padding_value=self.padding_idx,\n        )\n        return {\n            \"image_id\": torch.stack([d[\"image_id\"] for d in data], dim=0),\n            \"image\": torch.stack([d[\"image\"] for d in data], dim=0),\n            \"caption_tokens\": caption_tokens,\n            \"masked_labels\": masked_labels,\n            \"caption_lengths\": torch.stack([d[\"caption_lengths\"] for d in data]),\n        }\n\n    def _random_token_index(self) -> int:\n        return random.randint(0, self._vocab_size - 1)\n"
  },
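How `MaskedLmDataset` above chooses positions to mask can be sketched separately: boundary tokens are never masked, and `math.ceil` rounds the count up so at least one interior position is selected whenever one exists. This helper is illustrative only, not part of the repository:

```python
import math
import random

def sample_mask_positions(num_tokens: int, proportion: float = 0.15):
    # Candidate positions exclude the first and last ([SOS]/[EOS]) tokens.
    interior = list(range(1, num_tokens - 1))
    # ceil rounds up, so any non-empty interior yields >= 1 masked position.
    k = math.ceil((num_tokens - 2) * proportion)
    return sorted(random.sample(interior, k))

positions = sample_mask_positions(12)  # 10 interior tokens -> ceil(1.5) = 2
print(len(positions))  # 2
```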
  {
    "path": "virtex/data/tokenizers.py",
    "content": "from typing import Any, Dict, List\n\nimport sentencepiece as sp\n\n\nclass SentencePieceBPETokenizer:\n    r\"\"\"\n    A tokenizer based on `SentencePiece <https://github.com/google/sentencepiece>`_\n    with BPE sub-routine. It encodes caption strings into list of tokens.\n\n    Args:\n        model_path: Path to the ``.model`` file trained by SentencePiece.\n    \"\"\"\n    SP_SPACE = u\"▁\"\n\n    def __init__(self, model_path: str):\n        self.model_path = model_path\n\n        # Load pretrained tokenizer model.\n        self.model = sp.SentencePieceProcessor()\n        self.model.Load(model_path)\n\n    def __getstate__(self):\n        r\"\"\"\n        This magic method, along with ``__setstate__`` makes an object of this\n        class picklable (and usable while data loading with multiple workers).\n        \"\"\"\n        state_dict = self.__dict__.copy()\n        state_dict[\"model\"] = None\n        return state_dict\n\n    def __setstate__(self, state_dict: Dict[str, Any]):\n        self.__dict__ = state_dict\n\n        self.model = sp.SentencePieceProcessor()\n        self.model.Load(self.model_path)\n\n    def get_vocab_size(self) -> int:\n        r\"\"\"Return number of tokens in vocabulary (including special tokens).\"\"\"\n        return len(self.model)\n\n    def token_to_id(self, token: str) -> int:\n        r\"\"\"Get integer ID of a string token (``<unk>`` if does not exist).\"\"\"\n        # Since tokenizer uses subword regularization, one token may break down to multiple IDs.\n        # Keep trying till we get a single ID.\n        return self.model.piece_to_id(token)\n\n    def id_to_token(self, token_id: int) -> str:\n        r\"\"\"Get string token of an integer ID (``<unk>`` if does not exist).\"\"\"\n        return self.model.id_to_piece(token_id)\n\n    def encode(self, text: str) -> List[int]:\n        r\"\"\"Convert a text string to a list of integer token ids.\"\"\"\n        return self.model.EncodeAsIds(text)\n\n 
   def decode(self, token_ids: List[int]) -> str:\n        r\"\"\"Convert a sequence of token IDs to a text string.\"\"\"\n        return self.model.DecodeIds(token_ids)\n"
  },
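The `__getstate__`/`__setstate__` pair in `SentencePieceBPETokenizer` exists so the tokenizer can cross process boundaries during multi-worker data loading: the live SentencePiece handle is dropped before pickling and re-created from `model_path` after unpickling. A minimal, self-contained sketch of the same pattern (the `Resource` and `Reloadable` names are illustrative stand-ins, not part of the codebase):

```python
import pickle


class Resource:
    """Stand-in for an unpicklable handle (e.g. a loaded SentencePiece model)."""
    def __init__(self, path):
        self.path = path


class Reloadable:
    """Mimics the tokenizer's pickling strategy: drop the live handle before
    pickling, then re-create it from ``model_path`` after unpickling."""
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = Resource(model_path)

    def __getstate__(self):
        state_dict = self.__dict__.copy()
        state_dict["model"] = None  # the handle itself is never pickled
        return state_dict

    def __setstate__(self, state_dict):
        self.__dict__ = state_dict
        self.model = Resource(self.model_path)  # re-load from the stored path


tok = Reloadable("bpe.model")
clone = pickle.loads(pickle.dumps(tok))
assert clone.model is not tok.model          # a fresh handle, not the old one
assert clone.model.path == "bpe.model"       # re-created from the same path
```

Because only `model_path` (a plain string) travels through the pickle stream, each DataLoader worker rebuilds its own handle on first use.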
  {
    "path": "virtex/data/transforms.py",
    "content": "import albumentations as alb\nimport cv2\n\n\nclass HorizontalFlip(alb.BasicTransform):\n    r\"\"\"\n    Flip the image horizontally randomly (equally likely) and replace the\n    word \"left\" with \"right\" in the caption.\n\n    .. note::\n\n        This transform can also work on images only (without the captions).\n        Its behavior will be same as albumentations\n        :class:`~albumentations.augmentations.transforms.HorizontalFlip`.\n\n    Examples:\n        >>> flip = HorizontalFlip(p=0.5)\n        >>> out1 = flip(image=image, caption=caption)  # keys: {\"image\", \"caption\"}\n        >>> # Also works with images (without caption).\n        >>> out2 = flip(image=image)  # keys: {\"image\"}\n\n    \"\"\"\n\n    @property\n    def targets(self):\n        return {\"image\": self.apply, \"caption\": self.apply_to_caption}\n\n    def apply(self, img, **params):\n        return cv2.flip(img, 1)\n\n    def apply_to_caption(self, caption, **params):\n        caption = (\n            caption.replace(\"left\", \"[TMP]\")\n            .replace(\"right\", \"left\")\n            .replace(\"[TMP]\", \"right\")\n        )\n        return caption\n\n\nclass RandomResizedSquareCrop(alb.RandomResizedCrop):\n    r\"\"\"\n    A variant of :class:`albumentations.augmentations.transforms.RandomResizedCrop`\n    which assumes a square crop (width = height). Everything else is same.\n\n    Args:\n        size: Dimension of the width and height of the cropped image.\n    \"\"\"\n\n    def __init__(self, size: int, *args, **kwargs):\n        super().__init__(height=size, width=size, *args, **kwargs)\n\n\nclass CenterSquareCrop(alb.CenterCrop):\n    r\"\"\"\n    A variant of :class:`albumentations.augmentations.transforms.CenterCrop`\n    which assumes a square crop (width = height). 
Everything else is the same.\n\n    Args:\n        size: Dimension of the width and height of the cropped image.\n    \"\"\"\n\n    def __init__(self, size: int, *args, **kwargs):\n        super().__init__(height=size, width=size, *args, **kwargs)\n\n\nclass SquareResize(alb.Resize):\n    r\"\"\"\n    A variant of :class:`albumentations.augmentations.transforms.Resize` which\n    assumes a square resize (width = height). Everything else is the same.\n\n    Args:\n        size: Dimension of the width and height of the resized image.\n    \"\"\"\n\n    def __init__(self, size: int, *args, **kwargs):\n        super().__init__(height=size, width=size, *args, **kwargs)\n\n\n# =============================================================================\n#   SOME COMMON CONSTANTS AND IMAGE TRANSFORMS:\n#   These serve as references here, and are used as default params in many\n#   dataset class constructors.\n# -----------------------------------------------------------------------------\n\nIMAGENET_COLOR_MEAN = (0.485, 0.456, 0.406)\nr\"\"\"ImageNet color normalization mean in RGB format (values in 0-1).\"\"\"\n\nIMAGENET_COLOR_STD = (0.229, 0.224, 0.225)\nr\"\"\"ImageNet color normalization std in RGB format (values in 0-1).\"\"\"\n\nDEFAULT_IMAGE_TRANSFORM = alb.Compose(\n    [\n        alb.SmallestMaxSize(256, p=1.0),\n        CenterSquareCrop(224, p=1.0),\n        alb.Normalize(mean=IMAGENET_COLOR_MEAN, std=IMAGENET_COLOR_STD, p=1.0),\n    ]\n)\nr\"\"\"Default transform without any data augmentation (during pretraining).\"\"\"\n# =============================================================================\n"
  },
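The three-step replace in `HorizontalFlip.apply_to_caption` routes one word through a temporary placeholder because naively chaining two `str.replace` calls would undo the swap. A small sketch of the same trick (the standalone `flip_caption` helper is ours, for illustration only):

```python
def flip_caption(caption: str) -> str:
    # Same placeholder trick as HorizontalFlip.apply_to_caption: send "left"
    # through "[TMP]" so the second replace cannot clobber it.
    return (
        caption.replace("left", "[TMP]")
        .replace("right", "left")
        .replace("[TMP]", "right")
    )


# Naive chained replaces go wrong: "left" -> "right" -> back to "left",
# so the flip is silently lost.
naive = "dog left of cat".replace("left", "right").replace("right", "left")
assert naive == "dog left of cat"  # unchanged: the flip was undone

assert flip_caption("dog left of cat") == "dog right of cat"
assert flip_caption("left and right") == "right and left"
```

The placeholder guarantees each original occurrence is rewritten exactly once, even when both direction words appear in the same caption.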
  {
    "path": "virtex/factories.py",
    "content": "r\"\"\"\nThis module is a collection of *factories* for creating objects of datasets,\nmodels, optimizers and other useful components. For example, a ResNet-50\nvisual backbone can be created as:\n\n    .. code-block:: python\n\n        >>> # Explicitly by name, args and kwargs:\n        >>> backbone = VisualBackboneFactory.create(\n        ...     \"torchvision::resnet50\", pretrained=False\n        ... )\n        >>> # Directly from a config object:\n        >>> _C = Config(override_list=[\"MODEL.VISUAL.NAME\", \"torchvision::resnet50\"])\n        >>> backbone = VisualBackboneFactory.from_config(_C)\n\nCreating directly from :class:`~virtex.config.Config` is fast and simple, and\nensures minimal changes throughout the codebase upon any change in the call\nsignature of underlying class; or config hierarchy. Refer description of\nspecific factories for more details.\n\"\"\"\nimport re\nfrom functools import partial\nfrom typing import Any, Callable, Dict, Iterable, List\n\nimport albumentations as alb\nfrom torch import nn, optim\n\nimport virtex.data as vdata\nimport virtex.models as vmodels\nfrom virtex.config import Config\nfrom virtex.data import transforms as T\nfrom virtex.data.tokenizers import SentencePieceBPETokenizer\nfrom virtex.modules import visual_backbones, textual_heads\nfrom virtex.optim import Lookahead, lr_scheduler\n\nfrom virtex.utils.beam_search import AutoRegressiveBeamSearch\nfrom virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling\n\n\nclass Factory:\n    r\"\"\"\n    Base class for all factories. All factories must inherit this base class\n    and follow these guidelines for a consistent behavior:\n\n    * Factory objects cannot be instantiated, doing ``factory = SomeFactory()``\n      is illegal. 
Child classes should not implement ``__init__`` methods.\n    * All factories must have an attribute named ``PRODUCTS`` of type\n      ``Dict[str, Callable]``, which associates each class with a unique string\n      name which can be used to create it.\n    * All factories must implement one classmethod, :meth:`from_config`, which\n      contains logic for creating an object by taking the name and other\n      arguments directly from :class:`~virtex.config.Config`. They can use\n      :meth:`create` already implemented in this base class.\n    * :meth:`from_config` should not require extra arguments beyond the\n      config itself, unless necessary (such as model parameters for optimizer).\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {}\n\n    def __init__(self):\n        raise ValueError(\n            f\"\"\"Cannot instantiate {self.__class__.__name__} object, use\n            `create` classmethod to create a product from this factory.\n            \"\"\"\n        )\n\n    @classmethod\n    def create(cls, name: str, *args, **kwargs) -> Any:\n        r\"\"\"Create an object by its name, args and kwargs.\"\"\"\n        # Use ``cls.__name__`` here: inside a classmethod, ``cls`` is already\n        # the class, so ``cls.__class__.__name__`` would be the metaclass name.\n        if name not in cls.PRODUCTS:\n            raise KeyError(f\"{cls.__name__} cannot create {name}.\")\n\n        return cls.PRODUCTS[name](*args, **kwargs)\n\n    @classmethod\n    def from_config(cls, config: Config) -> Any:\n        r\"\"\"Create an object directly from config.\"\"\"\n        raise NotImplementedError\n\n\nclass TokenizerFactory(Factory):\n    r\"\"\"\n    Factory to create text tokenizers.
This codebase only supports one tokenizer\n    for now, but having a dedicated factory makes it easy to add more if needed.\n\n    Possible choices: ``{\"SentencePieceBPETokenizer\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        \"SentencePieceBPETokenizer\": SentencePieceBPETokenizer\n    }\n\n    @classmethod\n    def from_config(cls, config: Config) -> SentencePieceBPETokenizer:\n        r\"\"\"\n        Create a tokenizer directly from config.\n\n        Args:\n            config: Config object with all the parameters.\n        \"\"\"\n\n        _C = config\n\n        tokenizer = cls.create(\n            \"SentencePieceBPETokenizer\",\n            model_path=_C.DATA.TOKENIZER_MODEL,\n        )\n        return tokenizer\n\n\nclass ImageTransformsFactory(Factory):\n    r\"\"\"\n    Factory to create image transformations for common preprocessing and data\n    augmentations. These are a mix of default transformations from\n    `albumentations <https://albumentations.readthedocs.io/en/latest/>`_ and\n    some extended ones defined in :mod:`virtex.data.transforms`.\n\n    It uses sensible default values; however, custom values can be provided\n    with the name in dict syntax. Example: ``random_resized_crop::{'scale': (0.08, 1.0)}``\n\n    .. note::\n\n        This factory does not implement the :meth:`from_config` method.
It is only\n        used by :class:`PretrainingDatasetFactory` and\n        :class:`DownstreamDatasetFactory`.\n\n    Possible choices: ``{\"center_crop\", \"horizontal_flip\", \"random_resized_crop\",\n    \"normalize\", \"global_resize\", \"color_jitter\", \"smallest_resize\"}``.\n    \"\"\"\n\n    # fmt: off\n    PRODUCTS: Dict[str, Callable] = {\n        # Input resize transforms: whenever selected, these are always applied.\n        # These transforms require one positional argument: image dimension.\n        \"random_resized_crop\": partial(\n            T.RandomResizedSquareCrop, scale=(0.2, 1.0), ratio=(0.75, 1.333), p=1.0\n        ),\n        \"center_crop\": partial(T.CenterSquareCrop, p=1.0),\n        \"smallest_resize\": partial(alb.SmallestMaxSize, p=1.0),\n        \"global_resize\": partial(T.SquareResize, p=1.0),\n\n        # Keep hue limits small in color jitter because it changes color drastically\n        # and captions often mention colors. Apply with higher probability.\n        \"color_jitter\": partial(\n            alb.ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8\n        ),\n        \"horizontal_flip\": partial(T.HorizontalFlip, p=0.5),\n\n        # Color normalization: whenever selected, always applied.
This accepts images\n        # in [0, 255], requires mean and std in [0, 1] and normalizes to `N(0, 1)`.\n        \"normalize\": partial(\n            alb.Normalize, mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, p=1.0\n        ),\n    }\n    # fmt: on\n\n    @classmethod\n    def create(cls, name: str, *args, **kwargs) -> Any:\n        r\"\"\"Create an object by its name, args and kwargs.\"\"\"\n\n        if \"::\" in name:\n            name, __kwargs = name.split(\"::\")\n            _kwargs = eval(__kwargs)\n        else:\n            _kwargs = {}\n\n        _kwargs.update(kwargs)\n        return super().create(name, *args, **_kwargs)\n\n    @classmethod\n    def from_config(cls, config: Config):\n        r\"\"\"Augmentations cannot be created from config, only :meth:`create`.\"\"\"\n        raise NotImplementedError\n\n\nclass PretrainingDatasetFactory(Factory):\n    r\"\"\"\n    Factory to create :class:`~torch.utils.data.Dataset` s for pretraining\n    VirTex models. Datasets are created depending on pretraining task used.\n    Typically these datasets either provide image-caption pairs, or only images\n    from COCO Captions dataset (serialized to an LMDB file).\n\n    As an exception, the dataset for ``multilabel_classification`` provides\n    COCO images and labels of their bounding box annotations.\n\n    Possible choices: ``{\"bicaptioning\", \"captioning\", \"masked_lm\",\n    \"token_classification\", \"multilabel_classification\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        \"virtex\": vdata.CaptioningDataset,\n        \"bicaptioning\": vdata.CaptioningDataset,\n        \"captioning\": vdata.CaptioningDataset,\n        \"masked_lm\": vdata.MaskedLmDataset,\n        \"token_classification\": vdata.TokenClassificationDataset,\n        \"multilabel_classification\": vdata.MultiLabelClassificationDataset,\n    }\n\n    @classmethod\n    def from_config(cls, config: Config, split: str = \"train\"):\n        r\"\"\"\n        
Create a dataset directly from config. Names in this factory match the\n        names in :class:`PretrainingModelFactory` because both use the same config\n        parameter ``MODEL.NAME`` to create objects.\n\n        Args:\n            config: Config object with all the parameters.\n            split: Which dataset split to load. One of ``{\"train\", \"val\"}``.\n        \"\"\"\n\n        _C = config\n        # Every dataset needs these two args.\n        kwargs = {\"data_root\": _C.DATA.ROOT, \"split\": split}\n\n        # Create a list of image transformations based on transform names.\n        image_transform_list: List[Callable] = []\n\n        for name in getattr(_C.DATA, f\"IMAGE_TRANSFORM_{split.upper()}\"):\n            # Pass dimensions if cropping / resizing, else rely on the defaults\n            # as per `ImageTransformsFactory`.\n            if \"resize\" in name or \"crop\" in name:\n                image_transform_list.append(\n                    ImageTransformsFactory.create(name, _C.DATA.IMAGE_CROP_SIZE)\n                )\n            else:\n                image_transform_list.append(ImageTransformsFactory.create(name))\n\n        kwargs[\"image_transform\"] = alb.Compose(image_transform_list)\n\n        # Add dataset specific kwargs.\n        if _C.MODEL.NAME != \"multilabel_classification\":\n            tokenizer = TokenizerFactory.from_config(_C)\n            kwargs.update(\n                tokenizer=tokenizer,\n                max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,\n            )\n\n        if _C.MODEL.NAME == \"masked_lm\":\n            kwargs.update(\n                mask_proportion=_C.DATA.MASKED_LM.MASK_PROPORTION,\n                mask_probability=_C.DATA.MASKED_LM.MASK_PROBABILITY,\n                replace_probability=_C.DATA.MASKED_LM.REPLACE_PROBABILITY,\n            )\n\n        # Dataset names match with model names (and of course pretext names).\n        return cls.create(_C.MODEL.NAME, **kwargs)\n\n\nclass
DownstreamDatasetFactory(Factory):\n    r\"\"\"\n    Factory to create :class:`~torch.utils.data.Dataset` s for evaluating\n    VirTex models on downstream tasks.\n\n    Possible choices: ``{\"datasets/VOC2007\", \"datasets/imagenet\",\n    \"datasets/inaturalist\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        \"datasets/VOC2007\": vdata.VOC07ClassificationDataset,\n        \"datasets/imagenet\": vdata.ImageNetDataset,\n        \"datasets/inaturalist\": vdata.INaturalist2018Dataset,\n    }\n\n    @classmethod\n    def from_config(cls, config: Config, split: str = \"train\"):\n        r\"\"\"\n        Create a dataset directly from config. Names in this factory are paths\n        of dataset directories (relative to the project directory), because\n        config parameter ``DATA.ROOT`` is used to create objects.\n\n        Args:\n            config: Config object with all the parameters.\n            split: Which dataset split to load. One of ``{\"trainval\", \"test\"}``\n                for VOC2007, or one of ``{\"train\", \"val\"}`` for ImageNet.\n        \"\"\"\n\n        _C = config\n        # Every dataset needs these two args.\n        kwargs = {\"data_root\": _C.DATA.ROOT, \"split\": split}\n\n        # For VOC2007, `IMAGE_TRANSFORM_TRAIN` is used for \"trainval\" split and\n        # `IMAGE_TRANSFORM_VAL` is used for \"test\" split.\n        image_transform_names: List[str] = list(\n            _C.DATA.IMAGE_TRANSFORM_TRAIN\n            if \"train\" in split\n            else _C.DATA.IMAGE_TRANSFORM_VAL\n        )\n        # Create a list of image transformations based on names.\n        image_transform_list: List[Callable] = []\n\n        for name in image_transform_names:\n            # Pass dimensions for resize/crop, else rely on the defaults.\n            if name.split(\"::\")[0] in {\"random_resized_crop\", \"center_crop\", \"global_resize\"}:\n                transform = ImageTransformsFactory.create(name, 224)\n            elif name.split(\"::\")[0] in
{\"smallest_resize\"}:\n                transform = ImageTransformsFactory.create(name, 256)\n            else:\n                transform = ImageTransformsFactory.create(name)\n\n            image_transform_list.append(transform)\n\n        kwargs[\"image_transform\"] = alb.Compose(image_transform_list)\n\n        return cls.create(_C.DATA.ROOT, **kwargs)\n\n\nclass VisualBackboneFactory(Factory):\n    r\"\"\"\n    Factory to create :mod:`~virtex.modules.visual_backbones`. This factory\n    supports any ResNet-like model from\n    `Torchvision <https://pytorch.org/docs/stable/torchvision/models.html>`_.\n    Use the method name for model as in torchvision, for example,\n    ``torchvision::resnet50``, ``torchvision::wide_resnet50_2`` etc.\n\n    Possible choices: ``{\"torchvision\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        \"torchvision\": visual_backbones.TorchvisionVisualBackbone,\n    }\n\n    @classmethod\n    def from_config(cls, config: Config) -> visual_backbones.VisualBackbone:\n        r\"\"\"\n        Create a visual backbone directly from config.\n\n        Args:\n            config: Config object with all the parameters.\n        \"\"\"\n\n        _C = config\n        kwargs = {\"visual_feature_size\": _C.MODEL.VISUAL.FEATURE_SIZE}\n\n        if \"torchvision\" in _C.MODEL.VISUAL.NAME:\n            # Check the name for models from torchvision.\n            cnn_name = _C.MODEL.VISUAL.NAME.split(\"::\")[-1]\n            kwargs[\"pretrained\"] = _C.MODEL.VISUAL.PRETRAINED\n            kwargs[\"frozen\"] = _C.MODEL.VISUAL.FROZEN\n\n            return cls.create(\"torchvision\", cnn_name, **kwargs)\n        else:\n            return cls.create(_C.MODEL.VISUAL.NAME, **kwargs)\n\n\nclass TextualHeadFactory(Factory):\n    r\"\"\"\n    Factory to create :mod:`~virtex.modules.textual_heads`. 
Architectural\n    hyperparameters for transformers can be specified as ``name::*``.\n    For example, ``transdec_postnorm::L1_H1024_A16_F4096`` would create a\n    transformer textual head with ``L = 1`` layers, ``H = 1024`` hidden size,\n    ``A = 16`` attention heads, and ``F = 4096`` size of feedforward layers.\n\n    Textual head should be ``\"none\"`` for pretraining tasks which do not\n    involve language modeling, such as ``\"token_classification\"``.\n\n    Possible choices: ``{\"transdec_postnorm\", \"transdec_prenorm\", \"none\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        \"transdec_prenorm\": partial(\n            textual_heads.TransformerDecoderTextualHead, norm_first=True\n        ),\n        \"transdec_postnorm\": partial(\n            textual_heads.TransformerDecoderTextualHead, norm_first=False\n        ),\n        \"none\": textual_heads.LinearTextualHead,\n    }\n\n    @classmethod\n    def from_config(cls, config: Config) -> nn.Module:\n        r\"\"\"\n        Create a textual head directly from config.\n\n        Args:\n            config: Config object with all the parameters.\n        \"\"\"\n\n        _C = config\n        name = _C.MODEL.TEXTUAL.NAME\n        kwargs = {\n            \"visual_feature_size\": _C.MODEL.VISUAL.FEATURE_SIZE,\n            \"vocab_size\": _C.DATA.VOCAB_SIZE,\n        }\n\n        if \"trans\" in _C.MODEL.TEXTUAL.NAME:\n            # Get architectural hyper-params as per name by matching regex.\n            name, architecture = name.split(\"::\")\n            architecture = re.match(r\"L(\\d+)_H(\\d+)_A(\\d+)_F(\\d+)\", architecture)\n\n            num_layers = int(architecture.group(1))\n            hidden_size = int(architecture.group(2))\n            attention_heads = int(architecture.group(3))\n            feedforward_size = int(architecture.group(4))\n\n            # Mask the future tokens for autoregressive captioning.\n            mask_future = _C.MODEL.NAME in {\"virtex\", 
\"captioning\", \"bicaptioning\"}\n\n            kwargs.update(\n                hidden_size=hidden_size,\n                num_layers=num_layers,\n                attention_heads=attention_heads,\n                feedforward_size=feedforward_size,\n                dropout=_C.MODEL.TEXTUAL.DROPOUT,\n                mask_future_positions=mask_future,\n                max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,\n                padding_idx=_C.DATA.UNK_INDEX,\n            )\n        return cls.create(name, **kwargs)\n\n\nclass PretrainingModelFactory(Factory):\n    r\"\"\"\n    Factory to create :mod:`~virtex.models` for different pretraining tasks.\n\n    Possible choices: ``{\"bicaptioning\", \"captioning\", \"masked_lm\",\n    \"token_classification\", \"multilabel_classification\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        # First two are basically the same. Added for shorthand notation.\n        \"virtex\": vmodels.VirTexModel,\n        \"bicaptioning\": vmodels.BidirectionalCaptioningModel,\n        \"captioning\": vmodels.ForwardCaptioningModel,\n        \"masked_lm\": vmodels.MaskedLMModel,\n        \"token_classification\": vmodels.TokenClassificationModel,\n        \"multilabel_classification\": vmodels.MultiLabelClassificationModel,\n    }\n\n    @classmethod\n    def from_config(cls, config: Config) -> nn.Module:\n        r\"\"\"\n        Create a model directly from config.\n\n        Args:\n            config: Config object with all the parameters.\n        \"\"\"\n\n        _C = config\n\n        # Build visual and textual streams based on config.\n        visual = VisualBackboneFactory.from_config(_C)\n        textual = TextualHeadFactory.from_config(_C)\n\n        # Add model specific kwargs. 
Refer to the call signatures of specific models\n        # for matching kwargs here.\n        if _C.MODEL.NAME in {\"virtex\", \"captioning\", \"bicaptioning\"}:\n            kwargs = {\n                \"sos_index\": _C.DATA.SOS_INDEX,\n                \"eos_index\": _C.DATA.EOS_INDEX,\n                \"decoder\": CaptionDecoderFactory.from_config(_C),\n            }\n\n        elif _C.MODEL.NAME == \"token_classification\":\n            kwargs = {\n                \"ignore_indices\": [\n                    _C.DATA.UNK_INDEX,\n                    _C.DATA.SOS_INDEX,\n                    _C.DATA.EOS_INDEX,\n                    _C.DATA.MASK_INDEX,\n                ]\n            }\n        elif _C.MODEL.NAME == \"multilabel_classification\":\n            kwargs = {\"ignore_indices\": [0]}  # background index\n        else:\n            kwargs = {}\n\n        return cls.create(_C.MODEL.NAME, visual, textual, **kwargs)\n\n\nclass CaptionDecoderFactory(Factory):\n    r\"\"\"\n    Factory to create decoders for predicting captions from a VirTex model.\n\n    Possible choices: ``{\"beam_search\", \"nucleus_sampling\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        \"beam_search\": AutoRegressiveBeamSearch,\n        \"nucleus_sampling\": AutoRegressiveNucleusSampling,\n    }\n\n    @classmethod\n    def from_config(cls, config: Config) -> nn.Module:\n        r\"\"\"\n        Create a decoder directly from config.\n\n        Args:\n            config: Config object with all the parameters.\n        \"\"\"\n\n        _C = config\n        kwargs = {\n            \"eos_index\": _C.DATA.EOS_INDEX,\n            \"max_steps\": _C.MODEL.DECODER.MAX_DECODING_STEPS,\n        }\n        if _C.MODEL.DECODER.NAME == \"beam_search\":\n            kwargs[\"beam_size\"] = _C.MODEL.DECODER.BEAM_SIZE\n        elif _C.MODEL.DECODER.NAME == \"nucleus_sampling\":\n            kwargs[\"nucleus_size\"] = _C.MODEL.DECODER.NUCLEUS_SIZE\n\n        return cls.create(_C.MODEL.DECODER.NAME,
**kwargs)\n\n\nclass OptimizerFactory(Factory):\n    r\"\"\"Factory to create optimizers. Possible choices: ``{\"sgd\", \"adamw\"}``.\"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\"sgd\": optim.SGD, \"adamw\": optim.AdamW}\n\n    @classmethod\n    def from_config(\n        cls, config: Config, named_parameters: Iterable[Any]\n    ) -> optim.Optimizer:\n        r\"\"\"\n        Create an optimizer directly from config.\n\n        Args:\n            config: Config object with all the parameters.\n            named_parameters: Named parameters of model (retrieved by\n                ``model.named_parameters()``) for the optimizer. We use named\n                parameters to set different LR and turn off weight decay for\n                certain parameters based on their names.\n        \"\"\"\n\n        _C = config\n\n        # Set different learning rate for CNN and rest of the model during\n        # pretraining. This doesn't matter for downstream evaluation because\n        # there are no modules with \"cnn\" in their name.\n        # Also turn off weight decay for layer norm and bias in textual stream.\n        param_groups = []\n        for name, param in named_parameters:\n            wd = 0.0 if re.match(_C.OPTIM.NO_DECAY, name) else _C.OPTIM.WEIGHT_DECAY\n            lr = _C.OPTIM.CNN_LR if \"cnn\" in name else _C.OPTIM.LR\n            param_groups.append({\"params\": [param], \"lr\": lr, \"weight_decay\": wd})\n\n        if _C.OPTIM.OPTIMIZER_NAME == \"sgd\":\n            kwargs = {\"momentum\": _C.OPTIM.SGD_MOMENTUM}\n        else:\n            kwargs = {}\n\n        optimizer = cls.create(_C.OPTIM.OPTIMIZER_NAME, param_groups, **kwargs)\n        if _C.OPTIM.LOOKAHEAD.USE:\n            optimizer = Lookahead(\n                optimizer, k=_C.OPTIM.LOOKAHEAD.STEPS, alpha=_C.OPTIM.LOOKAHEAD.ALPHA\n            )\n        return optimizer\n\n\nclass LRSchedulerFactory(Factory):\n    r\"\"\"\n    Factory to create LR schedulers.
All schedulers have a built-in LR warmup\n    schedule before actual LR scheduling (decay) starts.\n\n    Possible choices: ``{\"none\", \"multistep\", \"linear\", \"cosine\"}``.\n    \"\"\"\n\n    PRODUCTS: Dict[str, Callable] = {\n        \"none\": lr_scheduler.LinearWarmupNoDecayLR,\n        \"multistep\": lr_scheduler.LinearWarmupMultiStepLR,\n        \"linear\": lr_scheduler.LinearWarmupLinearDecayLR,\n        \"cosine\": lr_scheduler.LinearWarmupCosineAnnealingLR,\n    }\n\n    @classmethod\n    def from_config(\n        cls, config: Config, optimizer: optim.Optimizer\n    ) -> optim.lr_scheduler.LambdaLR:\n        r\"\"\"\n        Create an LR scheduler directly from config.\n\n        Args:\n            config: Config object with all the parameters.\n            optimizer: Optimizer on which LR scheduling would be performed.\n        \"\"\"\n\n        _C = config\n        kwargs = {\n            \"total_steps\": _C.OPTIM.NUM_ITERATIONS,\n            \"warmup_steps\": _C.OPTIM.WARMUP_STEPS,\n        }\n        # Multistep LR requires multiplicative factor and milestones.\n        if _C.OPTIM.LR_DECAY_NAME == \"multistep\":\n            kwargs.update(gamma=_C.OPTIM.LR_GAMMA, milestones=_C.OPTIM.LR_STEPS)\n\n        return cls.create(_C.OPTIM.LR_DECAY_NAME, optimizer, **kwargs)\n"
  },
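The `transdec_postnorm::L1_H1024_A16_F4096` naming convention parsed in `TextualHeadFactory.from_config` can be exercised in isolation. A sketch using the same regex (the `parse_textual_head_name` helper is hypothetical, not part of the module):

```python
import re


def parse_textual_head_name(name: str):
    # Split a name like "transdec_postnorm::L1_H1024_A16_F4096" into its base
    # name and architectural hyperparameters, mirroring the regex used in
    # TextualHeadFactory.from_config.
    base, architecture = name.split("::")
    match = re.match(r"L(\d+)_H(\d+)_A(\d+)_F(\d+)", architecture)
    num_layers, hidden_size, attention_heads, feedforward_size = map(
        int, match.groups()
    )
    return base, dict(
        num_layers=num_layers,
        hidden_size=hidden_size,
        attention_heads=attention_heads,
        feedforward_size=feedforward_size,
    )


base, params = parse_textual_head_name("transdec_postnorm::L1_H1024_A16_F4096")
assert base == "transdec_postnorm"
assert params == {
    "num_layers": 1,
    "hidden_size": 1024,
    "attention_heads": 16,
    "feedforward_size": 4096,
}
```

Packing the architecture into the name keeps a single config key (`MODEL.TEXTUAL.NAME`) in charge of both which head to build and how large it is.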
  {
    "path": "virtex/model_zoo/__init__.py",
    "content": "from .model_zoo import get\n\n__all__ = [\"get\"]\n"
  },
  {
    "path": "virtex/model_zoo/model_zoo.py",
    "content": "r\"\"\"\nA utility module to easily load common VirTex models (optionally with pretrained\nweights) using a single line of code.\n\nGet our full best performing VirTex model (with pretrained weights as):\n\n>>> import virtex.model_zoo as mz\n>>> model = mz.get(\"width_ablations/bicaptioning_R_50_L1_H2048.yaml\", pretrained=True)\n\nAny config available in ``configs/`` directory under project root can be\nspecified here, although this command need not be executed from project root.\nFor more details on available models, refer :doc:`usage/model_zoo`.\n\nPart of this code is adapted from Detectron2's model zoo; which was originally\nimplemented by the developers of this codebase, with reviews and further\nchanges by Detectron2 developers.\n\"\"\"\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved\nimport os\nimport pkg_resources\n\nfrom fvcore.common.download import download\nimport torch\n\nfrom virtex.config import Config\nfrom virtex.factories import PretrainingModelFactory\nfrom virtex.utils.checkpointing import CheckpointManager\n\n\nclass _ModelZooUrls:\n    r\"\"\"Mapping from config names to URL suffixes of pretrained weights.\"\"\"\n\n    URL_PREFIX = \"https://www.dropbox.com/s\"\n\n    CONFIG_PATH_TO_DB_ID = {\n\n        # Pretraining Task Ablations\n        \"task_ablations/bicaptioning_R_50_L1_H2048.yaml\": \"mbeeso8wyieq8wy\",\n        \"task_ablations/captioning_R_50_L1_H2048.yaml\": \"r6zen9k43m5oo58\",\n        \"task_ablations/token_classification_R_50.yaml\": \"o4p9lki505r0mef\",\n        \"task_ablations/multilabel_classification_R_50.yaml\": \"hbspp3jv3u8h3bc\",\n        \"task_ablations/masked_lm_R_50_L1_H2048.yaml\": \"ldzrk6vem4mg6bl\",\n\n        # Width Ablations\n        \"width_ablations/bicaptioning_R_50_L1_H512.yaml\": \"o9fr69jjqfn8a65\",\n        \"width_ablations/bicaptioning_R_50_L1_H768.yaml\": \"1zxglqrrbfufv9d\",\n        \"width_ablations/bicaptioning_R_50_L1_H1024.yaml\": 
\"pdat4tvhnqxel64\",\n        \"width_ablations/bicaptioning_R_50_L1_H2048.yaml\": \"mbeeso8wyieq8wy\",\n\n        # Depth Ablations\n        \"depth_ablations/bicaptioning_R_50_L1_H1024.yaml\": \"pdat4tvhnqxel64\",\n        \"depth_ablations/bicaptioning_R_50_L2_H1024.yaml\": \"ft1vtt4okirzjgo\",\n        \"depth_ablations/bicaptioning_R_50_L3_H1024.yaml\": \"5ldo1rcsnrshmjr\",\n        \"depth_ablations/bicaptioning_R_50_L4_H1024.yaml\": \"zgiit2wcluuq3xh\",\n\n        # Backbone Ablations\n        \"backbone_ablations/bicaptioning_R_50_L1_H1024.yaml\": \"pdat4tvhnqxel64\",\n        \"backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml\": \"5o198ux709r6376\",\n        \"backbone_ablations/bicaptioning_R_101_L1_H1024.yaml\": \"bb74jubt68cpn80\",\n    }\n\n\ndef get(config_path: str, pretrained: bool = False):\n    r\"\"\"\n    Get a model specified by relative path under Detectron2's official\n    ``configs/`` directory.\n\n    Args:\n        config_path: Name of config file relative to ``configs/`` directory\n            under project root. (E.g. ``width_ablations/bicaptioning_R_50_L1_H2048.yaml``)\n        pretrained: If ``True``, will initialize the model with the pretrained\n            weights. 
If ``False``, the weights will be initialized randomly.\n    \"\"\"\n\n    # Get the original path to config file (shipped inside the package).\n    _pkg_config_path = pkg_resources.resource_filename(\n        \"virtex.model_zoo\", os.path.join(\"configs\", config_path)\n    )\n    if not os.path.exists(_pkg_config_path):\n        raise RuntimeError(\"{} not available in Model Zoo!\".format(config_path))\n\n    _C = Config(_pkg_config_path)\n    model = PretrainingModelFactory.from_config(_C)\n\n    if pretrained:\n        # Get URL for the checkpoint for this config path.\n        if config_path in _ModelZooUrls.CONFIG_PATH_TO_DB_ID:\n\n            dropbox_id = _ModelZooUrls.CONFIG_PATH_TO_DB_ID[config_path]\n            filename = os.path.basename(config_path).replace(\".yaml\", \".pth\")\n\n            checkpoint_url = f\"{_ModelZooUrls.URL_PREFIX}/{dropbox_id}/{filename}?dl=1\"\n        else:\n            raise RuntimeError(\"{} not available in Model Zoo!\".format(config_path))\n\n        # Download the pretrained model weights and save with a sensible name.\n        # The weights are downloaded only if they do not already exist locally.\n        checkpoint_path = download(\n            checkpoint_url,\n            dir=os.path.expanduser(\"~/.torch/virtex_cache\"),\n            filename=filename,\n        )\n        CheckpointManager(model=model).load(checkpoint_path)\n\n    return model\n"
  },
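In `get`, the local cache filename is derived from the config path by taking its basename and swapping the `.yaml` extension for `.pth`. A minimal sketch of that derivation (`checkpoint_filename` is an illustrative helper, not part of the module):

```python
import os


def checkpoint_filename(config_path: str) -> str:
    # Mirror of the derivation inside ``get``: basename of the config file
    # with its ".yaml" extension replaced by ".pth".
    return os.path.basename(config_path).replace(".yaml", ".pth")


name = checkpoint_filename("width_ablations/bicaptioning_R_50_L1_H2048.yaml")
assert name == "bicaptioning_R_50_L1_H2048.pth"
```

Keying the cached file on the config's basename means re-running `get` with the same config reuses the weights already sitting in `~/.torch/virtex_cache`.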
  {
    "path": "virtex/models/__init__.py",
    "content": "from .captioning import (\n    ForwardCaptioningModel,\n    BidirectionalCaptioningModel,\n    VirTexModel\n)\nfrom .masked_lm import MaskedLMModel\nfrom .classification import (\n    MultiLabelClassificationModel,\n    TokenClassificationModel,\n)\n\n\n__all__ = [\n    \"VirTexModel\",\n    \"BidirectionalCaptioningModel\",\n    \"ForwardCaptioningModel\",\n    \"MaskedLMModel\",\n    \"MultiLabelClassificationModel\",\n    \"TokenClassificationModel\",\n]\n"
  },
  {
    "path": "virtex/models/captioning.py",
    "content": "import copy\nimport functools\nfrom typing import Any, Dict\n\nimport torch\nfrom torch import nn\n\nfrom virtex.data.tokenizers import SentencePieceBPETokenizer\nfrom virtex.modules.textual_heads import TextualHead\nfrom virtex.modules.visual_backbones import VisualBackbone\n\n\nclass CaptioningModel(nn.Module):\n    r\"\"\"\n    A model to perform image captioning (in both forward and backward directions\n    independently, only in forward direction). It is composed of a\n    :class:`~virtex.modules.visual_backbones.VisualBackbone` and a\n    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.\n\n    During training, it maximizes the likelihood of ground truth caption\n    conditioned on image features. During inference, it predicts a caption for\n    an input image through beam search decoding.\n\n    Args:\n        visual: A :class:`~virtex.modules.visual_backbones.VisualBackbone` which\n            computes visual features from an input image.\n        textual: A :class:`~virtex.modules.textual_heads.TextualHead` which\n            makes final predictions conditioned on visual features.\n        sos_index: The index of the start token (``[SOS]``) in vocabulary.\n        eos_index: The index of the end token (``[EOS]``) in vocabulary.\n        caption_backward: Whether to *also* perform captioning in backward\n            direction. Default is ``False`` -- only forward captioning is\n            performed. 
When ``True``, a clone of textual head is created, which\n            shares no weights with the \"forward\" model except the visual projection\n            and input/output embeddings.\n        decoder: A :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`\n            or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`\n            object for decoding captions during inference (unused during training).\n    \"\"\"\n\n    def __init__(\n        self,\n        visual: VisualBackbone,\n        textual: TextualHead,\n        caption_backward: bool = False,\n        sos_index: int = 1,\n        eos_index: int = 2,\n        decoder: Any = None,\n    ):\n        super().__init__()\n        self.visual = visual\n        self.textual = textual\n        self.padding_idx = self.textual.padding_idx\n        self.caption_backward = caption_backward\n\n        # Clone the textual module for backward direction if doing captioning\n        # in both directions (separately).\n        if self.caption_backward:\n            self.backward_textual = copy.deepcopy(self.textual)\n\n            # Share weights for visual projection, and input/output embeddings.\n            self.backward_textual.visual_projection = self.textual.visual_projection\n            self.backward_textual.embedding = self.textual.embedding\n            self.backward_textual.output = self.textual.output\n\n        # These boundary indices are needed for beam search.\n        self.sos_index = sos_index\n        self.eos_index = eos_index\n        self.decoder = decoder\n        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)\n\n    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:\n        r\"\"\"\n        Given a batch of images and captions, compute log likelihood loss per\n        caption token during training. 
During inference (with images), predict\n        a caption through either beam search decoding or nucleus sampling.\n\n        Args:\n            batch: A batch of images and (optionally) ground truth caption tokens.\n                Possible set of keys: ``{\"image_id\", \"image\", \"caption_tokens\",\n                \"noitpac_tokens\", \"caption_lengths\"}``.\n\n        Returns:\n            A dict with the following structure, containing loss for optimization,\n            loss components to log directly to tensorboard, and optionally\n            predictions.\n\n            .. code-block::\n\n                {\n                    \"loss\": torch.Tensor,\n                    \"loss_components\": {\n                        \"captioning_forward\": torch.Tensor,\n                        \"captioning_backward\": torch.Tensor, (optional)\n                    },\n                    \"predictions\": torch.Tensor\n                }\n        \"\"\"\n\n        # shape: (batch_size, channels, height, width)\n        visual_features = self.visual(batch[\"image\"])\n        batch_size = visual_features.size(0)\n\n        if \"caption_tokens\" in batch:\n            caption_tokens = batch[\"caption_tokens\"]\n            caption_lengths = batch[\"caption_lengths\"]\n\n            # shape: (batch_size, max_caption_length, vocab_size)\n            output_logits = self.textual(\n                visual_features, caption_tokens, caption_lengths\n            )\n            loss = self.loss(\n                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),\n                caption_tokens[:, 1:].contiguous().view(-1),\n            )\n            output_dict: Dict[str, Any] = {\n                \"loss\": loss,\n                # Single scalar per batch for logging in training script.\n                \"loss_components\": {\"captioning_forward\": loss.clone().detach()},\n            }\n            # Do captioning in backward direction if specified.\n            
if self.caption_backward:\n                backward_caption_tokens = batch[\"noitpac_tokens\"]\n\n                backward_output_logits = self.backward_textual(\n                    visual_features, backward_caption_tokens, caption_lengths\n                )\n                backward_loss = self.loss(\n                    backward_output_logits[:, :-1]\n                    .contiguous()\n                    .view(-1, self.textual.vocab_size),\n                    backward_caption_tokens[:, 1:].contiguous().view(-1),\n                )\n                output_dict[\"loss\"] += backward_loss\n\n                # Single scalar per batch for logging in training script.\n                output_dict[\"loss_components\"].update(\n                    captioning_backward=backward_loss.clone().detach()\n                )\n\n            if not self.training:\n                # During validation (while pretraining), get best prediction\n                # at every timestep.\n                output_dict[\"predictions\"] = torch.argmax(output_logits, dim=-1)\n        else:\n            if self.decoder is None:\n                raise ValueError(\"Decoder for predicting captions is missing!\")\n\n            # During inference, get beam search predictions for forward\n            # model. 
Predictions from forward transformer will be shifted\n            # right by one timestep.\n            start_predictions = visual_features.new_full(\n                (batch_size,), self.sos_index\n            ).long()\n            # Add image features as a default argument to match callable\n            # signature accepted by beam search class (partial captions only).\n            decoding_step = functools.partial(self.decoding_step, visual_features)\n\n            predicted_caption, _ = self.decoder.search(\n                start_predictions, decoding_step\n            )\n            output_dict = {\"predictions\": predicted_caption}\n\n        return output_dict\n\n    def decoding_step(\n        self, visual_features: torch.Tensor, partial_captions: torch.Tensor\n    ) -> torch.Tensor:\n        r\"\"\"\n        Given visual features and a batch of (assumed) partial captions, predict\n        the logits over output vocabulary tokens for next timestep. This method\n        is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`\n        and :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`.\n\n        .. note::\n\n            For nucleus sampling, ``beam_size`` will always be 1 (not relevant).\n\n        Args:\n            visual_features: A tensor of shape ``(batch_size, channels, height,\n                width)`` containing features from the visual backbone; these are\n                projected inside the textual head.\n            partial_captions: A tensor of shape ``(batch_size * beam_size, timesteps)``\n                containing tokens predicted so far -- one for each beam. 
We need all\n                prior predictions because our model is auto-regressive.\n\n        Returns:\n            A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- logits\n            over output vocabulary tokens for next timestep.\n        \"\"\"\n\n        # Expand and repeat image features while doing beam search.\n        batch_size, channels, height, width = visual_features.size()\n        beam_size = int(partial_captions.size(0) / batch_size)\n        if beam_size > 1:\n            # shape: (batch_size * beam_size, channels, height, width)\n            visual_features = visual_features.unsqueeze(1).repeat(1, beam_size, 1, 1, 1)\n            visual_features = visual_features.view(\n                batch_size * beam_size, channels, height, width\n            )\n\n        # Provide caption lengths as current length (irrespective of predicted\n        # EOS/padding tokens). shape: (batch_size, )\n        caption_lengths = torch.ones_like(partial_captions)\n        if len(caption_lengths.size()) == 2:\n            caption_lengths = caption_lengths.sum(1)\n        else:\n            # Add a timestep. 
shape: (batch_size, 1)\n            partial_captions = partial_captions.unsqueeze(1)\n\n        # shape: (batch_size * beam_size, partial_caption_length, vocab_size)\n        logits = self.textual(visual_features, partial_captions, caption_lengths)\n        # Return logits from the last timestep.\n        return logits[:, -1, :]\n\n    def log_predictions(\n        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer\n    ) -> str:\n\n        self.eval()\n        with torch.no_grad():\n            predictions = self.forward(batch)[\"predictions\"]\n        self.train()\n\n        predictions_str = \"\"\n        for tokens, preds in zip(batch[\"caption_tokens\"], predictions):\n            predictions_str += f\"\"\"\n                Caption tokens : {tokenizer.decode(tokens.tolist())}\n                Predictions (f): {tokenizer.decode(preds.tolist())}\n\n                \"\"\"\n        return predictions_str\n\n\nclass ForwardCaptioningModel(CaptioningModel):\n    r\"\"\"\n    Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`\n    for better readability: this passes ``caption_backward=False`` to super class.\n    \"\"\"\n\n    def __init__(\n        self,\n        visual: VisualBackbone,\n        textual: TextualHead,\n        sos_index: int = 1,\n        eos_index: int = 2,\n        decoder: Any = None,\n    ):\n        super().__init__(\n            visual,\n            textual,\n            sos_index=sos_index,\n            eos_index=eos_index,\n            caption_backward=False,\n            decoder=decoder,\n        )\n\n\nclass BidirectionalCaptioningModel(CaptioningModel):\n    r\"\"\"\n    Convenient extension of :class:`~virtex.models.captioning.CaptioningModel`\n    for better readability: this passes ``caption_backward=True`` to super class.\n    \"\"\"\n\n    def __init__(\n        self,\n        visual: VisualBackbone,\n        textual: TextualHead,\n        sos_index: int = 1,\n        eos_index: int = 2,\n        
decoder: Any = None,\n    ):\n        super().__init__(\n            visual,\n            textual,\n            sos_index=sos_index,\n            eos_index=eos_index,\n            caption_backward=True,\n            decoder=decoder,\n        )\n\n\n# Convenient handle for our main model.\nVirTexModel = BidirectionalCaptioningModel\n"
  },
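The one-position shift used in `CaptioningModel.forward` (scoring `output_logits[:, :-1]` against `caption_tokens[:, 1:]`) can be exercised in isolation. A minimal sketch with made-up token indices and vocabulary size; only the shapes and the `ignore_index` behavior mirror the model code:

```python
import torch
from torch import nn

# Illustrative sketch of the teacher-forcing shift in CaptioningModel.forward:
# logits at timestep t are scored against the ground-truth token at t + 1,
# and [PAD] positions (index 0 here) are ignored by the loss.
vocab_size, padding_idx = 10, 0
loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx)

# One caption: [SOS] w w [EOS] [PAD] -- token indices are made up.
caption_tokens = torch.tensor([[1, 4, 7, 2, 0]])
output_logits = torch.randn(1, 5, vocab_size)  # (batch, timesteps, vocab)

# Drop the last timestep of logits and the first caption token so that
# position t in the logits lines up with the token it should predict.
shifted_logits = output_logits[:, :-1].contiguous().view(-1, vocab_size)
shifted_targets = caption_tokens[:, 1:].contiguous().view(-1)
loss = loss_fn(shifted_logits, shifted_targets)
```

The trailing `[PAD]` target contributes nothing to the loss because it equals `ignore_index`.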
  {
    "path": "virtex/models/classification.py",
    "content": "from typing import Any, Dict, List\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom virtex.data.tokenizers import SentencePieceBPETokenizer\nfrom virtex.modules.textual_heads import TextualHead\nfrom virtex.modules.visual_backbones import VisualBackbone\n\n\nclass ClassificationModel(nn.Module):\n    r\"\"\"\n    A model to perform classification (generally, with multiple targets). It is\n    composed of a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a\n    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.\n\n    .. note::\n\n        Of the currently available textual heads, only one is supported here:\n        :class:`~virtex.modules.textual_heads.LinearTextualHead`.\n\n    During training, it minimizes the KL-divergence loss with a K-hot vector\n    with values ``1/K``, where K is the number of unique labels to classify.\n\n    Args:\n        visual: A :class:`~virtex.modules.visual_backbones.VisualBackbone` which\n            computes visual features from an input image.\n        textual: A :class:`~virtex.modules.textual_heads.TextualHead` which\n            makes final predictions conditioned on visual features.\n        ignore_indices: Ignore a set of token indices while computing KL-divergence\n            loss. These are special tokens such as ``[SOS]``, ``[EOS]`` etc.\n    \"\"\"\n\n    def __init__(\n        self, visual: VisualBackbone, textual: TextualHead, ignore_indices: List[int]\n    ):\n        super().__init__()\n        self.visual = visual\n        self.textual = textual\n        self.ignore_indices = ignore_indices\n\n    def forward(self, batch: Dict[str, torch.Tensor]):\n        r\"\"\"\n        Given a batch of images and a set of labels, perform classification with\n        multiple targets by minimizing a KL-divergence loss.\n\n        Args:\n            batch: A batch of images and labels. 
Possible set of keys:\n                ``{\"image_id\", \"image\", \"labels\"}``\n\n        Returns:\n            A dict with the following structure, containing loss for optimization,\n            loss components to log directly to tensorboard, and optionally\n            predictions.\n\n            .. code-block::\n\n                {\n                    \"loss\": torch.Tensor,\n                    \"loss_components\": {\n                        \"classification\": torch.Tensor,\n                    },\n                    \"predictions\": torch.Tensor\n                }\n        \"\"\"\n\n        # shape: (batch_size, visual_feature_size, ...)\n        visual_features = self.visual(batch[\"image\"])\n        batch_size = visual_features.size(0)\n\n        # Get logits and further log-probabilities.\n        # shape: (batch_size, vocab_size)\n        logits = self.textual(visual_features)\n        logprobs = F.log_softmax(logits, dim=1)\n\n        # Average log-probs per unique token in associated caption to compute\n        # loss. This is simply cross-entropy with target-vector as a K-hot\n        # vector. Do in a for-loop, there isn't a straightforward vectorized way.\n        loss = torch.tensor(0.0, device=logprobs.device)\n\n        for index in range(batch_size):\n            # Get unique labels for particular instance.\n            unique_labels = batch[\"labels\"][index].unique()\n\n            # Ignore indices of special tokens such as [SOS], [EOS] etc. 
and\n            # any other token specified.\n            unique_labels = [l for l in unique_labels if l not in self.ignore_indices]\n            # Get log-probabilities corresponding to these tokens.\n            instance_logprobs = logprobs[index, unique_labels].mean()\n\n            # Accumulate negative log-probability for this instance in loss.\n            loss = loss - instance_logprobs\n\n        # Average loss across instances.\n        output_dict: Dict[str, Any] = {\"loss\": loss / batch_size}\n\n        # Single scalar per batch for logging to tensorboard in training script.\n        output_dict[\"loss_components\"] = {\n            \"classification\": loss.clone().detach() / batch_size\n        }\n        # Return top-10 tokens according to log-probabilities during validation.\n        # Useful for logging.\n        if not self.training:\n            top_logprobs, top_tokens = logprobs.topk(k=10, dim=1)\n            output_dict[\"predictions\"] = top_tokens\n\n        return output_dict\n\n\nclass TokenClassificationModel(ClassificationModel):\n    r\"\"\"\n    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`\n    for better readability (this only modifies the tensorboard logging logic).\n\n    Ground truth targets here are a set of unique caption tokens (ignoring the\n    special tokens like ``[SOS]``, ``[EOS]`` etc.).\n    \"\"\"\n\n    def log_predictions(\n        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer\n    ) -> str:\n\n        self.eval()\n        with torch.no_grad():\n            predictions = self.forward(batch)[\"predictions\"]\n        self.train()\n\n        predictions_str = \"\"\n        for tokens, preds in zip(batch[\"caption_tokens\"], predictions):\n            # Predictions here are individual tokens, and do not have any order\n            # like captions, so decode them separately so we don't strip off\n            # metaspace character and special tokens if 
any.\n            preds = [tokenizer.id_to_token(p) for p in preds.tolist()]\n            predictions_str += f\"\"\"\n                Caption tokens : {tokenizer.decode(tokens.tolist())}\n                Predictions (f): {\" \".join(preds)}\n\n                \"\"\"\n        return predictions_str\n\n\nclass MultiLabelClassificationModel(ClassificationModel):\n    r\"\"\"\n    Convenient extension of :class:`~virtex.models.classification.ClassificationModel`\n    for better readability (this only modifies the tensorboard logging logic).\n\n    Ground truth targets here are a set of unique instances in images (ignoring\n    the special background token, category id = 0 in COCO).\n    \"\"\"\n\n    def log_predictions(\n        self,\n        batch: Dict[str, torch.Tensor],\n        tokenizer: SentencePieceBPETokenizer = None,\n    ) -> str:\n        # We accept `tokenizer` for having consistent API but don't use it here.\n        self.eval()\n        with torch.no_grad():\n            predictions = self.forward(batch)[\"predictions\"]\n        self.train()\n\n        predictions_str = \"\"\n        for tokens, preds in zip(batch[\"caption_tokens\"], predictions):\n            # Predictions here are COCO category IDs, let them be as is.\n            # Sorted ground truth, remove background tokens.\n            tokens = sorted([t for t in tokens.tolist() if t != 0])\n            preds = sorted(preds.tolist()[: len(tokens)])\n            predictions_str += f\"\"\"\n                COCO Instance IDs (GT)   : {tokens}\n                COCO Instance IDs (Pred) : {preds}\n\n                \"\"\"\n        return predictions_str\n"
  },
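The per-instance loss in `ClassificationModel.forward` (negative mean of log-probabilities over unique, non-ignored labels) equals cross-entropy against a K-hot target with values `1/K`. A small sketch with fabricated logits and labels over a 4-entry "vocabulary":

```python
import torch
import torch.nn.functional as F

# Fabricated logits over a tiny vocabulary of 4 labels, for one instance.
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])
logprobs = F.log_softmax(logits, dim=1)

labels = torch.tensor([3, 1, 3, 0])  # duplicates plus one special token
ignore_indices = [0]                 # e.g. background / [SOS] / [EOS]

# Unique labels minus ignored indices, as in ClassificationModel.forward.
unique_labels = [int(l) for l in labels.unique() if l not in ignore_indices]
loss = -logprobs[0, unique_labels].mean()

# Equivalent formulation: cross-entropy with a K-hot target of values 1/K.
target = torch.zeros_like(logprobs)
target[0, unique_labels] = 1.0 / len(unique_labels)
kl_style_loss = -(target * logprobs).sum()
```

Both formulations give the same scalar, which is why the model can compute the loss with a simple gather-and-mean instead of materializing the K-hot vector.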
  {
    "path": "virtex/models/masked_lm.py",
    "content": "from typing import Any, Dict\n\nimport torch\nfrom torch import nn\n\nfrom virtex.data.tokenizers import SentencePieceBPETokenizer\nfrom virtex.modules.textual_heads import TextualHead\nfrom virtex.modules.visual_backbones import VisualBackbone\n\n\nclass MaskedLMModel(nn.Module):\n    r\"\"\"\n    A model to perform BERT-like masked language modeling. It is composed of a\n    :class:`~virtex.modules.visual_backbones.VisualBackbone` and a\n    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.\n\n    During training, the model receives caption tokens with certain tokens\n    replaced by the ``[MASK]`` token, and it predicts these masked tokens based on\n    surrounding context.\n\n    Args:\n        visual: A :class:`~virtex.modules.visual_backbones.VisualBackbone` which\n            computes visual features from an input image.\n        textual: A :class:`~virtex.modules.textual_heads.TextualHead` which\n            makes final predictions conditioned on visual features.\n    \"\"\"\n\n    def __init__(self, visual: VisualBackbone, textual: TextualHead):\n        super().__init__()\n        self.visual = visual\n        self.textual = textual\n        self.padding_idx = self.textual.padding_idx\n        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)\n\n    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:\n        r\"\"\"\n        Given a batch of images and captions with certain masked tokens,\n        predict the tokens at masked positions.\n\n        Args:\n            batch: A batch of images, ground truth caption tokens and masked labels.\n                Possible set of keys: ``{\"image_id\", \"image\", \"caption_tokens\",\n                \"masked_labels\", \"caption_lengths\"}``.\n\n        Returns:\n            A dict with the following structure, containing loss for optimization,\n            loss components to log directly to tensorboard, and optionally\n            predictions.\n\n            
.. code-block::\n\n                {\n                    \"loss\": torch.Tensor,\n                    \"loss_components\": {\"masked_lm\": torch.Tensor},\n                    \"predictions\": torch.Tensor\n                }\n        \"\"\"\n\n        # shape: (batch_size, channels, height, width)\n        visual_features = self.visual(batch[\"image\"])\n\n        caption_tokens = batch[\"caption_tokens\"]\n        caption_lengths = batch[\"caption_lengths\"]\n        masked_labels = batch[\"masked_labels\"]\n\n        # shape: (batch_size, num_caption_tokens, vocab_size)\n        output_logits = self.textual(visual_features, caption_tokens, caption_lengths)\n        output_dict: Dict[str, Any] = {\n            \"loss\": self.loss(\n                output_logits.view(-1, output_logits.size(-1)), masked_labels.view(-1)\n            )\n        }\n        # Single scalar per batch for logging in training script.\n        output_dict[\"loss_components\"] = {\n            \"masked_lm\": output_dict[\"loss\"].clone().detach()\n        }\n        # During evaluation, get predictions from logits. 
Useful for logging.\n        # Only the predictions at [MASK]ed positions are relevant.\n        if not self.training:\n            predictions = torch.argmax(output_logits, dim=-1)\n            redundant_positions = masked_labels == self.padding_idx\n            predictions[redundant_positions] = self.padding_idx\n\n            output_dict[\"predictions\"] = predictions\n\n        return output_dict\n\n    def log_predictions(\n        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer\n    ) -> str:\n\n        self.eval()\n        with torch.no_grad():\n            predictions = self.forward(batch)[\"predictions\"]\n        self.train()\n\n        predictions_str = \"\"\n        for tokens, labels, preds in zip(\n            batch[\"caption_tokens\"], batch[\"masked_labels\"], predictions\n        ):\n            predictions_str += f\"\"\"\n                Caption tokens : {tokenizer.decode(tokens.tolist())}\n                Masked Labels  : {tokenizer.decode(labels.tolist())}\n                Predictions    : {tokenizer.decode(preds.tolist())}\n                \"\"\"\n        return predictions_str\n"
  },
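The way `MaskedLMModel` keeps only predictions at `[MASK]`ed slots can be shown standalone. The logits and labels below are fabricated; only the argmax-then-overwrite pattern mirrors the model code:

```python
import torch

# Fabricated logits for 4 timesteps over a vocabulary of 6 tokens.
padding_idx = 0
output_logits = torch.randn(1, 4, 6)
# masked_labels is padding_idx everywhere except originally [MASK]ed slots.
masked_labels = torch.tensor([[0, 5, 0, 3]])

# As in MaskedLMModel.forward: take argmax everywhere, then overwrite
# positions that were never masked with the padding index.
predictions = torch.argmax(output_logits, dim=-1)
redundant_positions = masked_labels == padding_idx
predictions[redundant_positions] = padding_idx
```

After this step, decoding `predictions` shows the model's guesses only where the input actually contained a `[MASK]` token.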
  {
    "path": "virtex/modules/embedding.py",
    "content": "import functools\n\nimport torch\nfrom torch import nn\n\n\nclass WordAndPositionalEmbedding(nn.Module):\n    r\"\"\"\n    A :class:`~torch.nn.Module` for learned word embeddings and position\n    embeddings for input tokens. Each token is mapped to a fixed-dimensional\n    word embedding and a corresponding positional embedding based on its index.\n    These are summed together, followed by layer normalization and an optional\n    dropout.\n\n    Args:\n        vocab_size: Size of token vocabulary.\n        hidden_size: Size of token embedding vectors.\n        dropout: Probability for final dropout applied after layer normalization.\n        max_caption_length: Maximum length of input captions; this is used to create a\n            fixed positional embedding lookup table.\n        padding_idx: Token index of ``[PAD]`` token, word embedding for these tokens\n            will be a vector of zeroes (and not trainable).\n    \"\"\"\n    def __init__(\n        self,\n        vocab_size: int,\n        hidden_size: int,\n        dropout: float = 0.0,\n        max_caption_length: int = 30,\n        padding_idx: int = 0,\n    ):\n        super().__init__()\n        self.vocab_size = vocab_size\n        self.padding_idx = padding_idx\n\n        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)\n\n        # We provide no \"padding index\" for positional embeddings. 
We zero out\n        # the positional embeddings of padded positions as a post-processing.\n        self.positions = nn.Embedding(max_caption_length, hidden_size)\n        self.layer_norm = nn.LayerNorm(\n            hidden_size, eps=1e-8, elementwise_affine=True\n        )\n        self.dropout = nn.Dropout(p=dropout)\n\n    def forward(self, tokens: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Get combined word and positional embeddings for input tokens.\n\n        Args:\n            tokens: A tensor of shape ``(batch_size, max_caption_length)``\n                containing a batch of caption tokens, values in ``[0, vocab_size)``.\n\n        Returns:\n            A tensor of shape ``(batch_size, max_caption_length, hidden_size)``\n            containing corresponding token embeddings.\n        \"\"\"\n        position_indices = self._create_position_indices(tokens)\n\n        # shape: (batch_size, max_caption_length, hidden_size)\n        word_embeddings = self.words(tokens)\n        position_embeddings = self.positions(position_indices)\n\n        # shape: (batch_size, max_caption_length, hidden_size)\n        embeddings = self.layer_norm(word_embeddings + position_embeddings)\n        embeddings = self.dropout(embeddings)\n\n        # Zero-out embeddings for positions which have padding tokens.\n        # shape: (batch_size, max_caption_length, 1)\n        token_mask = (tokens != self.padding_idx).unsqueeze(-1)\n\n        # shape: (batch_size, max_caption_length, hidden_size)\n        embeddings = embeddings * token_mask.type(embeddings.dtype)\n        return embeddings\n\n    @functools.lru_cache(maxsize=128)\n    def _create_position_indices(self, tokens: torch.Tensor):\n\n        # Create position indices of the same size as token indices.\n        batch_size, max_caption_length = tokens.size()\n        positions = torch.arange(\n            max_caption_length, dtype=tokens.dtype, device=tokens.device\n        )\n        # shape: (batch_size, 
max_caption_length)\n        positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)\n        return positions\n"
  },
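The padding mask at the end of `WordAndPositionalEmbedding.forward` can be exercised with stand-in embeddings; the ones tensor below simply replaces the layer-normalized sum of word and position embeddings:

```python
import torch

# A caption of length 5, right-padded with padding_idx = 0.
padding_idx = 0
tokens = torch.tensor([[5, 9, 2, 0, 0]])

# Stand-in for layer_norm(word_embeddings + position_embeddings).
embeddings = torch.ones(1, 5, 8)

# As in WordAndPositionalEmbedding.forward: build a (batch, length, 1)
# boolean mask and zero out embeddings at padded positions.
token_mask = (tokens != padding_idx).unsqueeze(-1)
embeddings = embeddings * token_mask.type(embeddings.dtype)
```

This is needed because padded positions still receive non-zero positional embeddings from the lookup table; only the word embedding at `padding_idx` is guaranteed to be zero.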
  {
    "path": "virtex/modules/textual_heads.py",
    "content": "r\"\"\"\nA textual head accepts visual features from the visual backbone, and performs\ntask-specific modeling (captioning, classification etc.) to predict an output\ndistribution over vocabulary tokens for one or multiple time-steps in the batch.\n\"\"\"\nimport functools\n\nimport torch\nfrom torch import nn\nfrom typing import Optional\n\nfrom virtex.modules.embedding import WordAndPositionalEmbedding\n\n\nclass TextualHead(nn.Module):\n    r\"\"\"\n    Base class for all textual heads. All child classes can simply inherit\n    from :class:`~torch.nn.Module`, however this is kept here for uniform\n    type annotations.\n\n    Args:\n        visual_feature_size: Size (number of channels) of the input features\n            from the visual backbone.\n        vocab_size: Number of tokens in the output vocabulary.\n        hidden_size: Size of the token embedding vectors, or hidden state vector\n            of the language model.\n    \"\"\"\n\n    def __init__(self, visual_feature_size: int, vocab_size: int, hidden_size: int):\n        super().__init__()\n        self.visual_feature_size = visual_feature_size\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n\n    @property\n    def textual_feature_size(self):\n        r\"\"\"\n        Size of the last dimension of output right before the output linear\n        layer (which predicts a distribution over vocabulary tokens). This is\n        typically the same as :attr:`hidden_size` for most modules. 
This property\n        is used to add more modules on top of this.\n        \"\"\"\n        return self.hidden_size\n\n\nclass LinearTextualHead(TextualHead):\n    r\"\"\"\n    A textual head containing a single linear layer projecting from the visual\n    feature size to the output vocabulary size.\n\n    Args:\n        visual_feature_size: Size (number of channels) of the input features from\n            the visual backbone.\n        vocab_size: Number of tokens in the output vocabulary.\n    \"\"\"\n\n    def __init__(self, visual_feature_size: int, vocab_size: int, **kwargs):\n        # For API consistency.\n        hidden_size = visual_feature_size\n        super().__init__(visual_feature_size, vocab_size, hidden_size)\n        self.output = nn.Linear(visual_feature_size, vocab_size)\n\n    def forward(\n        self,\n        visual_features: torch.Tensor,\n        caption_tokens: Optional[torch.Tensor] = None,\n        caption_lengths: Optional[torch.Tensor] = None,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Project visual features directly to predict a distribution over\n        vocabulary tokens through a single linear layer. 
This textual head\n        ignores arguments ``caption_tokens`` and ``caption_lengths``; they\n        are here for API consistency.\n\n        Args:\n            visual_features: A tensor of shape ``(batch_size, channels, height,\n                width)`` containing features from visual backbone.\n\n        Returns:\n            A tensor of shape ``(batch_size, vocab_size)`` containing output\n            vocabulary logits.\n        \"\"\"\n\n        # Flatten spatial dimensions and convert visual features to NHWC format.\n        batch_size, channels, _, _ = visual_features.size()\n        visual_features = visual_features.view(batch_size, channels, -1)\n        visual_features = visual_features.permute(0, 2, 1)\n\n        # Perform global average pooling of visual features.\n        # shape: (batch_size, channels)\n        visual_features = visual_features.mean(dim=1)\n\n        # shape: (batch_size, vocab_size)\n        output_logits = self.output(visual_features)\n        return output_logits\n\n\nclass TransformerDecoderTextualHead(TextualHead):\n    r\"\"\"\n    A textual head composed of four main modules: (1) input projection (linear\n    layer) for visual features to match size with textual features, (2) word\n    and positional embedding for input captions, (3) a unidirectional transformer\n    decoder, and (4) an output projection (linear layer) to predict a\n    distribution over vocabulary tokens. The word embedding weights are tied\n    with output projection; the latter still has its own learnable bias.\n\n    .. note::\n\n        For the \"bicaptioning\" pretraining task, our *textual head* (as defined\n        in the paper) must have two transformer decoders: one each to decode the\n        caption in either direction. 
Each object of this class, however, contains only one\n        transformer.\n\n        Refer to the :class:`~virtex.models.captioning.BidirectionalCaptioningModel`\n        source to understand how an object of this class is cloned, along with\n        tying embedding and output weights, for bicaptioning.\n\n        Hence, while there are *two objects* of this class, it is pragmatically\n        a *single* textual head as a whole, according to the terminology used\n        in the paper.\n\n    Args:\n        visual_feature_size: Size (number of channels) of the input features from\n            the visual backbone.\n        vocab_size: Number of tokens in the output vocabulary.\n        hidden_size: Size of the token embedding vectors, or hidden state vector of\n            the language model.\n        num_layers: Number of layers in the transformer.\n        attention_heads: Number of attention heads in the transformer.\n        feedforward_size: Size of feedforward layers in the transformer.\n        dropout: Dropout probability for transformer (applied after layernorm).\n        norm_first: Whether to apply normalization before or after attention/FF\n            layers. The former are called pre-norm variants (like GPT-2) and the\n            latter post-norm variants (like BERT). Default is post-norm.\n        mask_future_positions: Whether to mask future positions for self-attention\n            over caption tokens. 
This must be ``True`` for captioning (and\n            bicaptioning) tasks to prevent the language model from cheating, and\n            ``False`` for masked language modeling, as the self-attention should\n            consider all tokens.\n        max_caption_length: Maximum length of input captions; this is used to\n            create a fixed positional embedding lookup table.\n        padding_idx: Token index of ``[PAD]`` token, word embedding for these\n            tokens will be a vector of zeroes (and not trainable).\n    \"\"\"\n\n    def __init__(\n        self,\n        visual_feature_size: int,\n        vocab_size: int,\n        hidden_size: int,\n        num_layers: int,\n        attention_heads: int,\n        feedforward_size: int,\n        dropout: float = 0.1,\n        norm_first: bool = False,\n        mask_future_positions: bool = True,\n        max_caption_length: int = 30,\n        padding_idx: int = 0,\n    ):\n        super().__init__(visual_feature_size, vocab_size, hidden_size)\n        self.num_layers = num_layers\n        self.attention_heads = attention_heads\n        self.feedforward_size = feedforward_size\n        self.dropout = dropout\n        self.mask_future_positions = mask_future_positions\n        self.padding_idx = padding_idx\n\n        self.visual_projection = nn.Linear(\n            visual_feature_size, self.textual_feature_size\n        )\n        self.embedding = WordAndPositionalEmbedding(\n            self.vocab_size,\n            self.textual_feature_size,\n            dropout=dropout,\n            max_caption_length=max_caption_length,\n            padding_idx=padding_idx,\n        )\n\n        # Initialize a unidirectional transformer using PyTorch modules\n        # (nn.TransformerDecoder and nn.TransformerDecoderLayer).\n        self.transformer = nn.TransformerDecoder(\n            nn.TransformerDecoderLayer(\n                self.textual_feature_size,\n                self.attention_heads,\n                
dim_feedforward=self.feedforward_size,\n                dropout=dropout,\n                activation=\"gelu\",\n                batch_first=True,\n                norm_first=norm_first,\n            ),\n            num_layers=self.num_layers,\n            # Add final layer norm for pre-norm transformers.\n            norm=nn.LayerNorm(self.hidden_size) if norm_first else None,\n        )\n        self.apply(self._init_weights)\n\n        # Create an output linear layer and tie the input and output word\n        # embeddings to reduce parameters.\n        self.output = nn.Linear(self.textual_feature_size, vocab_size)\n        self.output.weight = self.embedding.words.weight\n\n    @staticmethod\n    def _init_weights(module):\n        r\"\"\"Initialize weights like BERT - N(0.0, 0.02); zero the padding embedding.\"\"\"\n\n        if isinstance(module, nn.Linear):\n            module.weight.data.normal_(mean=0.0, std=0.02)\n        elif isinstance(module, nn.MultiheadAttention):\n            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)\n            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)\n        elif isinstance(module, nn.Embedding):\n            module.weight.data.normal_(mean=0.0, std=0.02)\n            if module.padding_idx is not None:\n                module.weight.data[module.padding_idx].zero_()\n\n    def forward(\n        self,\n        visual_features: torch.Tensor,\n        caption_tokens: torch.Tensor,\n        caption_lengths: torch.Tensor,\n    ) -> torch.Tensor:\n        r\"\"\"\n        Given (projected) visual features from visual backbone and caption\n        tokens, predict the output logits for next time-step.\n\n        Args:\n            visual_features: A tensor of shape ``(batch_size, channels, height,\n                width)`` containing features from visual backbone.\n            caption_tokens: A tensor of shape ``(batch_size, max_caption_length)``\n                of caption tokens padded to the right by ``padding_idx``.\n         
   caption_lengths: A tensor of shape ``(batch_size, )`` containing\n                lengths of caption tokens in the batch.\n\n        Returns:\n            A tensor of shape ``(batch_size, max_caption_length, vocab_size)``\n            containing output vocabulary logits for each time-step.\n        \"\"\"\n\n        # Reshape visual features to (batch_size, height * width, channels),\n        # then project them to textual feature size.\n        batch_size, channels, height, width = visual_features.size()\n        visual_features = visual_features.view(batch_size, channels, -1)\n        visual_features = visual_features.permute(0, 2, 1)\n\n        # shape: (batch_size, height * width, textual_feature_size)\n        projected_visual_features = self.visual_projection(visual_features)\n        # Now visual and textual features are of same size.\n\n        # Note that `max_caption_length` here may be less than the\n        # `max_caption_length` passed in `__init__`, but it does not matter.\n        batch_size, max_caption_length = caption_tokens.size()\n\n        # Create a mask based on caption lengths, shape: (batch_size, max_caption_length)\n        # Form a binary mask: it is True for padding positions.\n        # These positions will be ignored for multi-headed attention.\n        ones = torch.ones_like(caption_tokens)\n        caption_mask = caption_lengths.unsqueeze(1) < ones.cumsum(dim=1)\n\n        # shape: (batch_size, max_caption_length, textual_feature_size)\n        caption_embeddings = self.embedding(caption_tokens)\n\n        if self.mask_future_positions:\n            # An additive mask for masking the future (one direction).\n            future_mask = self.make_future_mask(\n                max_caption_length, caption_embeddings.dtype, caption_embeddings.device\n            )\n        else:\n            future_mask = None\n\n        # shape: (batch_size, max_caption_length, hidden_size)\n        textual_features = self.transformer(\n            caption_embeddings,\n            projected_visual_features,\n            
tgt_mask=future_mask,\n            tgt_key_padding_mask=caption_mask,\n        )\n        # shape: (batch_size, max_caption_length, vocab_size)\n        output_logits = self.output(textual_features)\n        return output_logits\n\n    @staticmethod\n    @functools.cache\n    def make_future_mask(\n        size: int, dtype: torch.dtype, device: torch.device\n    ) -> torch.Tensor:\n        \"\"\"\n        Generate a mask for \"future\" positions. Masked positions will be negative\n        infinity. This mask is critical for causal language modeling.\n        \"\"\"\n        return torch.triu(\n            torch.full((size, size), float(\"-inf\"), dtype=dtype, device=device),\n            diagonal=1,\n        )\n"
  },
  {
    "path": "virtex/modules/visual_backbones.py",
    "content": "from typing import Any, Dict\n\nimport torch\nfrom torch import nn\nimport torchvision\n\n\nclass VisualBackbone(nn.Module):\n    r\"\"\"\n    Base class for all visual backbones. All child classes can simply inherit\n    from :class:`~torch.nn.Module`, however this is kept here for uniform\n    type annotations.\n    \"\"\"\n\n    def __init__(self, visual_feature_size: int):\n        super().__init__()\n        self.visual_feature_size = visual_feature_size\n\n\nclass TorchvisionVisualBackbone(VisualBackbone):\n    r\"\"\"\n    A visual backbone from `Torchvision model zoo\n    <https://pytorch.org/docs/stable/torchvision/models.html>`_. Any model can\n    be specified using corresponding method name from the model zoo.\n\n    Args:\n        name: Name of the model from Torchvision model zoo.\n        visual_feature_size: Size of the channel dimension of output visual\n            features from forward pass.\n        pretrained: Whether to load ImageNet pretrained weights from Torchvision.\n        frozen: Whether to keep all weights frozen during training.\n    \"\"\"\n\n    def __init__(\n        self,\n        name: str = \"resnet50\",\n        visual_feature_size: int = 2048,\n        pretrained: bool = False,\n        frozen: bool = False,\n    ):\n        super().__init__(visual_feature_size)\n\n        self.cnn = getattr(torchvision.models, name)(\n            pretrained, zero_init_residual=True\n        )\n        # Do nothing after the final residual stage.\n        self.cnn.fc = nn.Identity()\n\n        # Freeze all weights if specified.\n        if frozen:\n            for param in self.cnn.parameters():\n                param.requires_grad = False\n            self.cnn.eval()\n\n    def forward(self, image: torch.Tensor) -> torch.Tensor:\n        r\"\"\"\n        Compute visual features for a batch of input images.\n\n        Args:\n            image: Batch of input images. 
A tensor of shape ``(batch_size, 3,\n                height, width)``.\n\n        Returns:\n            A tensor of shape ``(batch_size, channels, height, width)``, for\n            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.\n        \"\"\"\n\n        for idx, (name, layer) in enumerate(self.cnn.named_children()):\n            out = layer(image) if idx == 0 else layer(out)\n\n            # These are the spatial features we need.\n            if name == \"layer4\":\n                # shape: (batch_size, channels, height, width)\n                return out\n\n    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:\n        r\"\"\"\n        Return state dict of visual backbone which can be loaded with\n        `Detectron2 <https://github.com/facebookresearch/detectron2>`_.\n        This is useful for downstream tasks based on Detectron2 (such as\n        object detection and instance segmentation). This method renames\n        certain parameters from Torchvision-style to Detectron2-style.\n\n        Returns:\n            A dict with three keys: ``{\"model\", \"__author__\", \"matching_heuristics\"}``.\n            These are necessary keys for loading this state dict properly with\n            Detectron2.\n        \"\"\"\n        # Detectron2 backbones have slightly different module names; this mapping\n        # lists substrings of module names required to be renamed for loading a\n        # torchvision model into Detectron2.\n        DETECTRON2_RENAME_MAPPING: Dict[str, str] = {\n            \"layer1\": \"res2\",\n            \"layer2\": \"res3\",\n            \"layer3\": \"res4\",\n            \"layer4\": \"res5\",\n            \"bn1\": \"conv1.norm\",\n            \"bn2\": \"conv2.norm\",\n            \"bn3\": \"conv3.norm\",\n            \"downsample.0\": \"shortcut\",\n            \"downsample.1\": \"shortcut.norm\",\n        }\n        # Populate this dict by renaming module names.\n        d2_backbone_dict: Dict[str, torch.Tensor] = 
{}\n\n        for name, param in self.cnn.state_dict().items():\n            for old, new in DETECTRON2_RENAME_MAPPING.items():\n                name = name.replace(old, new)\n\n            # First conv and bn module parameters are prefixed with \"stem.\".\n            if not name.startswith(\"res\"):\n                name = f\"stem.{name}\"\n\n            d2_backbone_dict[name] = param\n\n        return {\n            \"model\": d2_backbone_dict,\n            \"__author__\": \"Karan Desai\",\n            \"matching_heuristics\": True,\n        }\n"
  },
  {
    "path": "virtex/optim/__init__.py",
    "content": "from .lookahead import Lookahead\n\n__all__ = [\"Lookahead\"]\n"
  },
  {
    "path": "virtex/optim/lookahead.py",
    "content": "r\"\"\"\n`Lookahead Optimizer: k steps forward, 1 step back <https://arxiv.org/abs/1907.08610>`_.\n\nThis implementation is adapted with minimal modifications from the\n`authors' implementation <https://github.com/michaelrzhang/lookahead>`_.\n\nIf you take it from here, please cite them:\n\n.. code-block:: text\n\n    @inproceedings{zhang2019lookahead,\n        title={Lookahead Optimizer: k steps forward, 1 step back},\n        author={Zhang, Michael R and Lucas, James and Hinton, Geoffrey and Ba, Jimmy},\n        booktitle={NeurIPS},\n        year={2019}\n    }\n\"\"\"\nfrom collections import defaultdict\nfrom typing import Any, Callable, Dict\n\nimport torch\nfrom torch.optim.optimizer import Optimizer\n\n\nclass Lookahead(Optimizer):\n    r\"\"\"\n    Implements Lookahead optimizer.\n\n    Args:\n        optimizer: Wrapped inner optimizer. The weights it manages will be the\n            \"fast\" weights.\n        k: Number of lookahead steps before updating \"slow\" weights.\n        alpha: Linear interpolation factor, 1.0 recovers inner optimizer.\n    \"\"\"\n\n    def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.8):\n        self.optimizer = optimizer\n        self.k = k\n        self.alpha = alpha\n\n        # Counter for inner optimizer.\n        self._k_counter = 0\n\n        # Cache the current optimizer parameters\n        self.state: Dict[str, Any] = defaultdict(dict)\n        for group in optimizer.param_groups:\n            for p in group[\"params\"]:\n                param_state = self.state[p]\n                param_state[\"slow_params\"] = torch.zeros_like(p.data)\n                param_state[\"slow_params\"].copy_(p.data)\n\n    def __getstate__(self):\n        return {\n            \"state\": self.state,\n            \"optimizer\": self.optimizer,\n            \"alpha\": self.alpha,\n            \"k\": self.k,\n            \"_k_counter\": self._k_counter,\n        }\n\n    @property\n    def 
param_groups(self):\n        return self.optimizer.param_groups\n\n    def zero_grad(self):\n        r\"\"\"Clear all grad buffers at the start of new forward pass.\"\"\"\n        self.optimizer.zero_grad()\n\n    def state_dict(self):\n        return self.optimizer.state_dict()\n\n    def load_state_dict(self, state_dict: Dict[str, Any]):\n        self.optimizer.load_state_dict(state_dict)\n\n        # Cache optimizer parameters after loading state dict.\n        for group in self.optimizer.param_groups:\n            for p in group[\"params\"]:\n                param_state = self.state[p]\n                param_state[\"slow_params\"] = torch.zeros_like(p.data)\n                param_state[\"slow_params\"].copy_(p.data)\n\n    def step(self, closure: Callable = None):\n        r\"\"\"\n        Perform a single Lookahead optimization step.\n\n        Args:\n            closure: A callable that re-evaluates the model and returns loss.\n        \"\"\"\n        loss = self.optimizer.step(closure)\n        self._k_counter += 1\n\n        if self._k_counter >= self.k:\n            self._k_counter = 0\n            # Lookahead and cache the current optimizer parameters\n            for group in self.optimizer.param_groups:\n                for p in group[\"params\"]:\n                    param_state = self.state[p]\n                    p.data.mul_(self.alpha).add_(\n                        param_state[\"slow_params\"], alpha=1.0 - self.alpha\n                    )\n                    param_state[\"slow_params\"].copy_(p.data)\n        return loss\n\n    def load_slow_weights(self):\n        r\"\"\"\n        Load slow weights from Lookahead optimizer. Useful for performing\n        evaluation on the slow weights (which typically generalize better).\n\n        This method backs up fast weights to load them after evaluation. 
No\n        need to call this method if evaluation happens just after a lookahead\n        step.\n        \"\"\"\n        for group in self.optimizer.param_groups:\n            for p in group[\"params\"]:\n                param_state = self.state[p]\n                param_state[\"backup_params\"] = torch.zeros_like(p.data)\n                param_state[\"backup_params\"].copy_(p.data)\n                p.data.copy_(param_state[\"slow_params\"])\n\n    def restore_fast_weights(self):\n        r\"\"\"\n        Restore fast weights for optimization. Call this after evaluation if\n        :meth:`load_slow_weights` was called.\n        \"\"\"\n        for group in self.optimizer.param_groups:\n            for p in group[\"params\"]:\n                param_state = self.state[p]\n                p.data.copy_(param_state[\"backup_params\"])\n                del param_state[\"backup_params\"]\n"
  },
  {
    "path": "virtex/optim/lr_scheduler.py",
    "content": "import bisect\nimport math\nfrom typing import List\n\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\nclass LinearWarmupNoDecayLR(LambdaLR):\n    r\"\"\"\n    A learning rate scheduler which linearly increases learning rate from 0\n    LR, and further keeps it constant throughout training.\n\n    Args:\n        optimizer: Wrapped optimizer.\n        total_steps: Total epochs (or iterations) for training.\n        warmup_steps: Number of first few steps to do linear warmup.\n        last_epoch: The index of last step (epoch or iteration). We named it\n            ``last_epoch`` instead of ``last_step`` to keep the naming consistent\n            with other LR schedulers in PyTorch.\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        total_steps: int,\n        warmup_steps: int,\n        last_epoch: int = -1,\n    ):\n        assert (\n            warmup_steps < total_steps\n        ), \"Warmup steps should be less than total steps.\"\n\n        self.tsteps = total_steps\n        self.wsteps = warmup_steps\n        super().__init__(optimizer, self._lr_multiplier, last_epoch)\n\n    def _lr_multiplier(self, step: int) -> float:\n        multiplier = step / float(max(1, self.wsteps)) if step < self.wsteps else 1\n        return max(0, multiplier)\n\n\nclass LinearWarmupMultiStepLR(LambdaLR):\n    r\"\"\"\n    A learning rate scheduler which linearly increases learning rate from 0\n    LR, and further decreases it by gamma once the number of steps reaches one\n    of the milestones.\n\n    Args:\n        optimizer: Wrapped optimizer.\n        total_steps: Total epochs (or iterations) for training.\n        warmup_steps: Number of first few steps to do linear warmup.\n        last_epoch: The index of last step (epoch or iteration). 
We named it\n            ``last_epoch`` instead of ``last_step`` to keep the naming consistent\n            with other LR schedulers in PyTorch.\n        milestones: List of step indices (epochs or iterations depending on\n            context). Must be increasing.\n        gamma: Multiplicative factor of learning rate decay.\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        total_steps: int,\n        warmup_steps: int,\n        milestones: List[int],\n        gamma: float = 0.1,\n        last_epoch: int = -1,\n    ):\n        self.wsteps = warmup_steps\n        self.milestones = milestones\n        self.gamma = gamma\n\n        # Keep track of the number of milestones encountered.\n        self.milestones_so_far = 0\n\n        # Common sanity checks.\n        assert milestones == sorted(milestones), \"milestones must be increasing\"\n        assert milestones[0] > warmup_steps, \"first milestone must be after warmup\"\n        assert (\n            milestones[-1] < total_steps\n        ), \"last milestone must be less than total steps\"\n\n        super().__init__(optimizer, self._lr_multiplier, last_epoch)\n\n    def _lr_multiplier(self, step: int) -> float:\n        if step < self.wsteps:\n            # Linear warmup.\n            multiplier = step / float(max(1, self.wsteps))\n        else:\n            # Step decay based on milestones.\n            multiplier = self.gamma ** bisect.bisect_right(self.milestones, step)\n\n        # Avoid negative learning rate.\n        return max(0, multiplier)\n\n\nclass LinearWarmupLinearDecayLR(LambdaLR):\n    r\"\"\"\n    A learning rate scheduler which linearly increases learning rate from 0\n    LR, and further decreases it linearly to zero.\n\n    Args:\n        optimizer: 
Wrapped optimizer.\n        total_steps: Total epochs (or iterations) for training.\n        warmup_steps: Number of first few steps to do linear warmup.\n        last_epoch: The index of last step (epoch or iteration). We named it\n            ``last_epoch`` instead of ``last_step`` to keep the naming consistent\n            with other LR schedulers in PyTorch.\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        total_steps: int,\n        warmup_steps: int,\n        last_epoch: int = -1,\n    ):\n        assert (\n            warmup_steps < total_steps\n        ), \"Warmup steps should be less than total steps.\"\n\n        self.tsteps = total_steps\n        self.wsteps = warmup_steps\n        super().__init__(optimizer, self._lr_multiplier, last_epoch)\n\n    def _lr_multiplier(self, step: int) -> float:\n        if step < self.wsteps:\n            # Linear warmup.\n            multiplier = step / float(max(1, self.wsteps))\n        else:\n            # Linear decay.\n            multiplier = (self.tsteps - step) / (self.tsteps - self.wsteps)\n        # Avoid negative learning rate.\n        return max(0, multiplier)\n\n\nclass LinearWarmupCosineAnnealingLR(LambdaLR):\n    r\"\"\"\n    A learning rate scheduler which linearly increases learning rate from 0\n    LR, and further decreases it to zero by cosine decay. After linear warmup,\n    the LR decays as:\n\n    .. math::\n        \\eta_t = \\eta_{max}\\cos^2(\\frac{T_{cur} - T_{warm}}{T_{max} - T_{warm}}\\frac{\\pi}{2})\n\n    Args:\n        optimizer: Wrapped optimizer.\n        total_steps: Total epochs (or iterations) for training.\n        warmup_steps: Number of first few steps to do linear warmup.\n        last_epoch: The index of last step (epoch or iteration). 
We named it\n            ``last_epoch`` instead of ``last_step`` to keep the naming consistent\n            with other LR schedulers in PyTorch.\n    \"\"\"\n\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        total_steps: int,\n        warmup_steps: int,\n        last_epoch: int = -1,\n    ):\n        assert (\n            warmup_steps < total_steps\n        ), \"Warmup steps should be less than total steps.\"\n\n        self.tsteps = total_steps\n        self.wsteps = warmup_steps\n        super().__init__(optimizer, self._lr_multiplier, last_epoch)\n\n    def _lr_multiplier(self, step: int) -> float:\n        if step < self.wsteps:\n            # Linear warmup.\n            multiplier = step / float(max(1, self.wsteps))\n        else:\n            # Cosine annealing decay.\n            cos_factor = (step - self.wsteps) / (self.tsteps - self.wsteps)\n            multiplier = math.cos(cos_factor * (math.pi / 2)) ** 2\n        # Avoid negative learning rate.\n        return max(0, multiplier)\n"
  },
  {
    "path": "virtex/utils/beam_search.py",
    "content": "r\"\"\"\nThis Beam Search implementation is adapted with minor modifications from\n`AllenNLP <https://github.com/allenai/allennlp/blob/master/allennlp/nn/beam_search.py>`_.\n\nThanks to the developers of AllenNLP!\n\n**Update (v1.2):** The \"backpointer\" trick in Beam Search (as implemented in\nAllenNLP) does not work well with autoregressive models (transformers). It is\nnow removed and it improves qualitative predictions and captioning metrics\n(CIDEr/SPICE) for VirTex. Updated captioning results are on ArXiv v3. Refer\n`CHANGELOG <https://github.com/kdexd/virtex/blob/master/CHANGELOG.md>`_ and\n`Release Page <https://github.com/kdexd/virtex/releases/tag/v1.2>`_ for more\ndetails.\n\nHuge thanks to Nicolas Carion (@alcinos) and Aishwarya Kamath (@ashkamath) for\nhelping me fix this bug!\n\"\"\"\nfrom typing import Callable, Tuple\nimport warnings\n\nimport torch\nfrom torch.nn import functional as F\n\n\nclass AutoRegressiveBeamSearch:\n    r\"\"\"\n    Implements the beam search algorithm for decoding the most likely captions.\n\n    Args:\n        eos_index: The index of the end token (``[EOS]``) in vocabulary.\n        max_steps: The maximum number of decoding steps.\n        beam_size: The width of the beam used.\n        per_node_beam_size: The maximum number of candidates to consider per node,\n            at each step in the search. Setting this parameter to a number smaller\n            than ``beam_size`` may give better results, as it can introduce more\n            diversity into the search. See `Beam Search Strategies for Neural\n            Machine Translation. 
Freitag and Al-Onaizan, 2017 <https://arxiv.org/abs/1702.01806>`_.\n    \"\"\"\n\n    def __init__(\n        self,\n        eos_index: int,\n        max_steps: int = 50,\n        beam_size: int = 5,\n        per_node_beam_size: int = 2,\n    ) -> None:\n        self._eos_index = eos_index\n        self.max_steps = max_steps\n        self.beam_size = beam_size\n        self.per_node_beam_size = per_node_beam_size or beam_size\n\n    def search(\n        self,\n        start_predictions: torch.Tensor,\n        step: Callable[..., torch.Tensor],\n        only_return_best: bool = True,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        r\"\"\"\n        Given a starting state and a step function, apply beam search to find\n        the most likely target captions.\n\n        Args:\n            start_predictions: Tensor containing the initial predictions, shape\n                ``(batch_size, )``. Usually the initial predictions are just the\n                index of the start token (``[SOS]``) in the vocabulary.\n            step: A function that is responsible for computing the next most likely\n                tokens, given the past predictions. Predictions from all previous\n                timesteps are required, not just the last timestep. The function is\n                expected to return a tensor of shape ``(group_size, target_vocab_size)``\n                containing the token logits for the next step.\n            only_return_best: Whether to only return the best beam (with highest\n                logprobs). Set this to ``False`` to return all the beams. 
If this is\n                ``True``, then the returned tensor is of shape ``(batch_size,\n                sequence_length)``, else will be ``(batch_size, beam_size,\n                sequence_length)``.\n\n        Returns:\n            Tuple of ``(predictions, logprobs)``, where ``predictions``\n            has shape ``(batch_size, beam_size, sequence_length)`` and ``logprobs``\n            has shape ``(batch_size, beam_size)``; the ``beam_size`` dimension is\n            dropped when ``only_return_best`` is ``True``.\n        \"\"\"\n\n        batch_size = start_predictions.size()[0]\n\n        # Accumulated predictions, shape: (batch_size, beam_size, length).\n        # Does not include the start symbols, which are implicit.\n        predictions: torch.Tensor = torch.empty(\n            (batch_size, self.beam_size, 0),\n            dtype=torch.long,\n            device=start_predictions.device,\n        )\n        # Calculate the first timestep. This is done outside the main loop\n        # because we are going from a single decoder input (the output from the\n        # encoder) to the top `beam_size` decoder outputs. On the other hand,\n        # within the main loop we are going from the `beam_size` elements of the\n        # beam to `beam_size * per_node_beam_size` candidates from which we will\n        # select the top `beam_size` elements for the next iteration.\n        # shape: (batch_size, num_classes)\n        start_class_logits = step(start_predictions)\n\n        # Convert logits to logprobs.\n        # shape: (batch_size, vocab_size)\n        start_class_logprobs = F.log_softmax(start_class_logits, dim=1)\n\n        num_classes = start_class_logprobs.size()[1]\n\n        # shape: (batch_size, beam_size), (batch_size, beam_size)\n        start_top_logprobs, start_predicted_classes = start_class_logprobs.topk(\n            self.beam_size\n        )\n        if self.beam_size == 1 and (start_predicted_classes == self._eos_index).all():\n            warnings.warn(\n                \"Empty captions predicted. 
You may want to increase beam \"\n                \"size or ensure your step function is working properly.\",\n                RuntimeWarning,\n            )\n            return start_predicted_classes.unsqueeze(-1), start_top_logprobs\n\n        # The log probs for the last time step.\n        # shape: (batch_size, beam_size)\n        last_logprobs = start_top_logprobs\n\n        # shape: (batch_size, beam_size, sequence_length)\n        predictions = torch.cat(\n            [predictions, start_predicted_classes.unsqueeze(-1)], dim=-1\n        )\n\n        # Log probability tensor that mandates that the end token is selected.\n        # shape: (batch_size * beam_size, num_classes)\n        logprobs_after_end = start_class_logprobs.new_full(\n            (batch_size * self.beam_size, num_classes), float(\"-inf\")\n        )\n        logprobs_after_end[:, self._eos_index] = 0.0\n\n        for timestep in range(self.max_steps - 1):\n            # shape: (batch_size * beam_size,)\n            last_predictions = predictions[:, :, -1].reshape(\n                batch_size * self.beam_size\n            )\n\n            # If every predicted token from the last step is `self._eos_index`,\n            # then we can stop early.\n            if (last_predictions == self._eos_index).all():\n                break\n\n            predictions_so_far = predictions.view(batch_size * self.beam_size, -1)\n            # shape: (batch_size * beam_size, num_classes)\n            class_logits = step(predictions_so_far)\n\n            # Convert logits to logprobs.\n            # shape: (batch_size * beam_size, vocab_size)\n            class_logprobs = F.log_softmax(class_logits, dim=1)\n\n            # Set logprobs of last predicted tokens as high negative value to avoid\n            # repetition in caption.\n            for index in range(batch_size * self.beam_size):\n                class_logprobs[index, predictions_so_far[index, -1]] = -10000\n\n            # shape: (batch_size * 
beam_size, num_classes)\n            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(\n                batch_size * self.beam_size, num_classes\n            )\n            # Here we are finding any beams where we predicted the end token in\n            # the previous timestep and replacing the distribution with a\n            # one-hot distribution, forcing the beam to predict the end token\n            # this timestep as well.\n            # shape: (batch_size * beam_size, num_classes)\n            cleaned_logprobs = torch.where(\n                last_predictions_expanded == self._eos_index,\n                logprobs_after_end,\n                class_logprobs,\n            )\n            # shape (both): (batch_size * beam_size, per_node_beam_size)\n            top_logprobs, predicted_classes = cleaned_logprobs.topk(\n                self.per_node_beam_size\n            )\n            # Here we expand the last log probs to `(batch_size * beam_size,\n            # per_node_beam_size)` so that we can add them to the current log\n            # probs for this timestep. 
This lets us maintain the log\n            # probability of each element on the beam.\n            # shape: (batch_size * beam_size, per_node_beam_size)\n            expanded_last_logprobs = (\n                last_logprobs.unsqueeze(2)\n                .expand(batch_size, self.beam_size, self.per_node_beam_size)\n                .reshape(batch_size * self.beam_size, self.per_node_beam_size)\n            )\n            # shape: (batch_size * beam_size, per_node_beam_size)\n            summed_top_logprobs = top_logprobs + expanded_last_logprobs\n\n            # shape: (batch_size, beam_size * per_node_beam_size)\n            reshaped_summed = summed_top_logprobs.reshape(\n                batch_size, self.beam_size * self.per_node_beam_size\n            )\n            # shape: (batch_size, beam_size * per_node_beam_size)\n            reshaped_predicted_classes = predicted_classes.reshape(\n                batch_size, self.beam_size * self.per_node_beam_size\n            )\n            # Append the predictions to the current beam.\n            reshaped_beam = (\n                predictions.view(batch_size * self.beam_size, 1, -1)\n                .repeat(1, self.per_node_beam_size, 1)\n                .reshape(batch_size, self.beam_size * self.per_node_beam_size, -1)\n            )\n            reshaped_beam = torch.cat(\n                [reshaped_beam, reshaped_predicted_classes.unsqueeze(-1)], dim=-1\n            )\n\n            # Keep only the top `beam_size` beam indices.\n            # shape: (batch_size, beam_size), (batch_size, beam_size)\n            restricted_beam_logprobs, restricted_beam_indices = reshaped_summed.topk(\n                self.beam_size\n            )\n            predictions = reshaped_beam.gather(\n                1,\n                restricted_beam_indices.unsqueeze(-1).repeat(\n                    1, 1, reshaped_beam.shape[-1]\n                ),\n            )\n\n            # shape: (batch_size, beam_size)\n            last_logprobs = 
restricted_beam_logprobs\n\n        if not torch.isfinite(last_logprobs).all():\n            warnings.warn(\n                \"Infinite log probs encountered. Some final captions may not \"\n                \"make sense. This can happen when the beam size is larger than\"\n                \" the number of valid (non-zero probability) transitions that \"\n                \"the step function produces.\",\n                RuntimeWarning,\n            )\n\n        # Optionally select best beam and its logprobs.\n        if only_return_best:\n            # shape: (batch_size, sequence_length)\n            predictions = predictions[:, 0, :]\n            last_logprobs = last_logprobs[:, 0]\n\n        return predictions, last_logprobs\n"
  },
  {
    "path": "virtex/utils/checkpointing.py",
    "content": "import copy\nimport pathlib\nfrom typing import Any, Dict, List, Optional\n\nfrom loguru import logger\nimport torch\nfrom torch import nn\n\nimport virtex.utils.distributed as dist\n\n\nclass CheckpointManager:\n    r\"\"\"\n    A helper class to periodically serialize models and other checkpointable\n    objects (optimizers, LR schedulers etc., which implement ``state_dict``\n    method) during training, and optionally record best performing checkpoint\n    based on an observed metric.\n\n    .. note::\n\n        For :class:`~torch.nn.parallel.DistributedDataParallel` objects,\n        ``state_dict`` of internal model is serialized.\n\n    .. note::\n\n        The observed metric for keeping best checkpoint is assumed \"higher is\n        better\", flip the sign if otherwise.\n\n    Args:\n        serialization_dir: Path to a directory to save checkpoints.\n        keep_recent: Number of recent ``k`` checkpoints to keep on disk. Older\n            checkpoints will be removed. Set to a very large value for keeping\n            all checkpoints.\n        checkpointables: Keyword arguments with any checkpointable objects, for\n            example: model, optimizer, learning rate scheduler.\n\n    Examples:\n        >>> model = torch.nn.Linear(10, 2)\n        >>> optimizer = torch.optim.Adam(model.parameters())\n        >>> ckpt_manager = CheckpointManager(\"/tmp\", model=model, optimizer=optimizer)\n        >>> num_epochs = 20\n        >>> for epoch in range(num_epochs):\n        ...     train(model)\n        ...     val_loss = validate(model)\n        ...     
ckpt_manager.step(epoch, - val_loss)\n    \"\"\"\n\n    def __init__(\n        self,\n        serialization_dir: str = \"/tmp\",\n        keep_recent: int = 200,\n        **checkpointables: Any,\n    ):\n        self.serialization_dir = pathlib.Path(serialization_dir)\n        self.keep_recent = keep_recent\n\n        # Shallow copy, keeps references to tensors as original objects.\n        self.checkpointables = copy.copy(checkpointables)\n\n        # Initialize members to hold state dict of best checkpoint and its\n        # performance.\n        self._best_metric: float = -1e-12\n        self._best_ckpt: Dict[str, Any] = {}\n\n        # Keep epoch/iteration numbers of recently saved 'k' checkpoints.\n        self._recent_iterations: List[int] = []\n\n    def step(self, iteration: int, metric: Optional[float] = None):\n        r\"\"\"\n        Serialize checkpoint and update best checkpoint based on metric. Keys\n        in serialized checkpoint match those in :attr:`checkpointables`.\n\n        Args:\n            iteration: Current training iteration. Will be saved with other\n                checkpointables.\n            metric: Observed metric (higher is better) for keeping track of the\n                best checkpoint. 
If this is ``None``, best checkpoint will not be\n                recorded/updated.\n        \"\"\"\n\n        checkpointable_state_dict: Dict[str, Any] = self._state_dict()\n\n        # We also checkpoint current iteration.\n        checkpointable_state_dict[\"iteration\"] = iteration\n\n        # Update the best checkpoint based on metric, if provided.\n        if metric is not None and metric > self._best_metric:\n            self._best_metric = metric\n            self._best_ckpt = copy.copy(checkpointable_state_dict)\n\n        # Serialize checkpoint corresponding to current iteration.\n        torch.save(\n            checkpointable_state_dict,\n            self.serialization_dir / f\"checkpoint_{iteration}.pth\",\n        )\n        if self._best_metric != -1e-12:\n            # Serialize best performing checkpoint observed so far.\n            torch.save(\n                self._best_ckpt, self.serialization_dir / \"checkpoint_best.pth\"\n            )\n\n        # Remove earliest checkpoint if there are more on disk.\n        self._recent_iterations.append(iteration)\n        if len(self._recent_iterations) > self.keep_recent:\n            self.remove_earliest_checkpoint()\n\n    def _state_dict(self):\n        r\"\"\"Return a dict containing state dict of all checkpointables.\"\"\"\n\n        __state_dict: Dict[str, Any] = {}\n        for key in self.checkpointables:\n            if isinstance(\n                self.checkpointables[key], nn.parallel.DistributedDataParallel\n            ):\n                __state_dict[key] = self.checkpointables[key].module.state_dict()\n            else:\n                __state_dict[key] = self.checkpointables[key].state_dict()\n\n        return __state_dict\n\n    def remove_earliest_checkpoint(self):\n        r\"\"\"Remove earliest serialized checkpoint from disk.\"\"\"\n\n        earliest_iteration = self._recent_iterations.pop(0)\n        (self.serialization_dir / f\"checkpoint_{earliest_iteration}.pth\").unlink()\n\n 
   def load(self, checkpoint_path: str):\n        r\"\"\"\n        Load a serialized checkpoint from a path. This method will try to find\n        each of :attr:`checkpointables` in the file and load its state dict.\n        Since our checkpointables are held as references, this method does not\n        return them.\n\n        Args:\n            checkpoint_path: Path to a checkpoint serialized by :meth:`step`.\n\n        Returns:\n            Iteration corresponding to the loaded checkpoint. Useful for\n            resuming training. This will be -1 in case of best checkpoint,\n            or if info does not exist.\n        \"\"\"\n\n        # Each process will log a message after loading checkpoint.\n        rank = dist.get_rank()\n\n        logger.info(f\"Rank {rank}: Loading checkpoint from {checkpoint_path}\")\n        checkpoint = torch.load(checkpoint_path, map_location=\"cpu\")\n        iteration = checkpoint.pop(\"iteration\", -1)\n\n        # Keep flags of all checkpointables to log which ones were not loaded.\n        is_loaded = {key: False for key in self.checkpointables}\n\n        # Load each checkpointable from checkpoint.\n        for key in checkpoint:\n            if key in self.checkpointables:\n                logger.info(f\"Rank {rank}: Loading {key} from {checkpoint_path}\")\n\n                if isinstance(\n                    self.checkpointables[key], nn.parallel.DistributedDataParallel\n                ):\n                    self.checkpointables[key].module.load_state_dict(checkpoint[key])\n                else:\n                    self.checkpointables[key].load_state_dict(checkpoint[key])\n\n                is_loaded[key] = True\n            else:\n                logger.info(f\"Rank {rank}: {key} not found in `checkpointables`.\")\n\n        not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]]\n        if len(not_loaded) > 0:\n            logger.info(\n                f\"Rank {rank}: Checkpointables not found in 
file: {not_loaded}\"\n            )\n        return iteration\n"
  },
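The keep-recent rotation in `CheckpointManager.step` can be checked in isolation. Below is a minimal pure-Python sketch; the `RecentFileRotator` name and plain marker files are illustrative stand-ins for `CheckpointManager` and `torch.save`, and `keep_recent=3` is just a small example value, not the class default:

```python
import os
import tempfile

class RecentFileRotator:
    # Minimal sketch of CheckpointManager's keep-recent rotation:
    # remember each saved iteration and, once the list exceeds
    # `keep_recent`, delete the file for the earliest one.
    def __init__(self, directory, keep_recent=3):
        self.directory = directory
        self.keep_recent = keep_recent
        self._recent_iterations = []

    def step(self, iteration):
        # Stand-in for torch.save(...): write a tiny marker file.
        path = os.path.join(self.directory, f"checkpoint_{iteration}.pth")
        with open(path, "w") as f:
            f.write(str(iteration))
        self._recent_iterations.append(iteration)
        if len(self._recent_iterations) > self.keep_recent:
            earliest = self._recent_iterations.pop(0)
            os.remove(os.path.join(self.directory, f"checkpoint_{earliest}.pth"))

directory = tempfile.mkdtemp()
rotator = RecentFileRotator(directory, keep_recent=3)
for it in range(5):
    rotator.step(it)
kept = sorted(os.listdir(directory))  # only the 3 most recent survive
```

After five steps with `keep_recent=3`, only the files for iterations 2, 3 and 4 remain, matching the FIFO pop in `remove_earliest_checkpoint`.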
  {
    "path": "virtex/utils/common.py",
    "content": "import argparse\nimport os\nimport random\nimport sys\n\nfrom loguru import logger\nimport numpy as np\nimport torch\n\nfrom virtex.config import Config\nimport virtex.utils.distributed as dist\n\n\ndef cycle(dataloader, device, start_iteration: int = 0):\n    r\"\"\"\n    A generator to yield batches of data from dataloader infinitely.\n\n    Internally, it sets the ``epoch`` for dataloader sampler to shuffle the\n    examples. One may optionally provide the starting iteration to make sure\n    the shuffling seed is different and continues naturally.\n    \"\"\"\n    iteration = start_iteration\n\n    while True:\n        if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):\n            # Set the `epoch` of DistributedSampler as current iteration. This\n            # is a way of determinisitic shuffling after every epoch, so it is\n            # just a seed and need not necessarily be the \"epoch\".\n            logger.info(f\"Beginning new epoch, setting shuffle seed {iteration}\")\n            dataloader.sampler.set_epoch(iteration)\n\n        for batch in dataloader:\n            for key in batch:\n                batch[key] = batch[key].to(device)\n            yield batch\n            iteration += 1\n\n\ndef common_setup(_C: Config, _A: argparse.Namespace, job_type: str = \"pretrain\"):\n    r\"\"\"\n    Setup common stuff at the start of every pretraining or downstream\n    evaluation job, all listed here to avoid code duplication. Basic steps:\n\n    1. Fix random seeds and other PyTorch flags.\n    2. Set up a serialization directory and loggers.\n    3. Log important stuff such as config, process info (useful during\n        distributed training).\n    4. Save a copy of config to serialization directory.\n\n    .. note::\n\n        It is assumed that multiple processes for distributed training have\n        already been launched from outside. 
Functions from the\n        :mod:`virtex.utils.distributed` module are used to get process info.\n\n    Args:\n        _C: Config object with all the parameters.\n        _A: Argparse command line arguments.\n        job_type: Type of job for which setup is to be done; one of\n            ``{\"pretrain\", \"downstream\"}``.\n    \"\"\"\n\n    # Get process rank and world size (assuming distributed is initialized).\n    RANK = dist.get_rank()\n    WORLD_SIZE = dist.get_world_size()\n\n    # For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html\n    torch.manual_seed(_C.RANDOM_SEED)\n    torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC\n    torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK\n    random.seed(_C.RANDOM_SEED)\n    np.random.seed(_C.RANDOM_SEED)\n\n    # Create serialization directory and save config in it.\n    os.makedirs(_A.serialization_dir, exist_ok=True)\n    _C.dump(os.path.join(_A.serialization_dir, f\"{job_type}_config.yaml\"))\n\n    # Remove default logger, create a logger for each process which writes to a\n    # separate log-file. This makes changes in global scope.\n    logger.remove(0)\n    if dist.get_world_size() > 1:\n        logger.add(\n            os.path.join(_A.serialization_dir, f\"log-rank{RANK}.txt\"),\n            format=\"{time} {level} {message}\",\n        )\n\n    # Add a logger for stdout only for the master process.\n    if dist.is_master_process():\n        logger.add(\n            sys.stdout, format=\"<g>{time}</g>: <lvl>{message}</lvl>\", colorize=True\n        )\n\n    # Print process info, config and args.\n    logger.info(f\"Rank of current process: {RANK}. 
World size: {WORLD_SIZE}\")\n    logger.info(str(_C))\n\n    logger.info(\"Command line args:\")\n    for arg in vars(_A):\n        logger.info(\"{:<20}: {}\".format(arg, getattr(_A, arg)))\n\n\ndef common_parser(description: str = \"\") -> argparse.ArgumentParser:\n    r\"\"\"\n    Create an argument parser with some common arguments useful for any\n    pretraining or downstream evaluation scripts.\n\n    Args:\n        description: Description to be used with the argument parser.\n\n    Returns:\n        A parser object with added arguments.\n    \"\"\"\n    parser = argparse.ArgumentParser(description=description)\n\n    # fmt: off\n    parser.add_argument(\n        \"--config\", metavar=\"FILE\", help=\"Path to a pretraining config file.\"\n    )\n    parser.add_argument(\n        \"--config-override\", nargs=\"*\", default=[],\n        help=\"A list of key-value pairs to modify pretraining config params.\",\n    )\n    parser.add_argument(\n        \"--serialization-dir\", default=\"/tmp/virtex\",\n        help=\"Path to a directory to serialize checkpoints and save job logs.\"\n    )\n\n    group = parser.add_argument_group(\"Compute resource management arguments.\")\n    group.add_argument(\n        \"--cpu-workers\", type=int, default=0,\n        help=\"Number of CPU workers per GPU to use for data loading.\",\n    )\n    group.add_argument(\n        \"--num-machines\", type=int, default=1,\n        help=\"Number of machines used in distributed training.\"\n    )\n    group.add_argument(\n        \"--num-gpus-per-machine\", type=int, default=0,\n        help=\"\"\"Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as\n        zero for single-process CPU training.\"\"\",\n    )\n    group.add_argument(\n        \"--machine-rank\", type=int, default=0,\n        help=\"\"\"Rank of the machine, integer in [0, num_machines). 
Default 0\n        for training with a single machine.\"\"\",\n    )\n    group.add_argument(\n        \"--dist-url\", default=\"tcp://127.0.0.1:23456\",\n        help=\"\"\"URL of the master process in distributed training; it defaults\n        to localhost for single-machine training.\"\"\",\n    )\n    # fmt: on\n\n    return parser\n"
  },
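The infinite iteration pattern in `cycle` can be seen in miniature without a DataLoader. A minimal sketch over a plain list; the `(iteration, item)` tuple yield is an illustrative addition for inspection, not the utility's actual interface (which yields device-moved batches):

```python
def cycle(items, start_iteration=0):
    # Mirror of virtex's cycle(): loop over the data forever while a
    # single global iteration counter keeps increasing. The real
    # utility also moves tensors to a device and, for distributed
    # runs, reseeds shuffling via sampler.set_epoch(iteration) at the
    # start of each pass over the data.
    iteration = start_iteration
    while True:
        for item in items:
            yield iteration, item
            iteration += 1

gen = cycle(["a", "b", "c"])
first_five = [next(gen) for _ in range(5)]  # wraps around after "c"
```

Note that the counter never resets across passes, which is exactly what lets the real utility use it as an ever-changing shuffle seed.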
  {
    "path": "virtex/utils/distributed.py",
    "content": "r\"\"\"\nA collection of common utilities for distributed training. These are a bunch of\nwrappers over utilities from :mod:`torch.distributed` module, but they do not\nraise exceptions in absence of distributed training / CPU-only training, and\nfall back to sensible default behavior.\n\"\"\"\nfrom typing import Callable, Dict, Tuple, Union\n\nfrom loguru import logger\nimport torch\nfrom torch import distributed as dist\nfrom torch import multiprocessing as mp\n\n\ndef launch(\n    job_fn: Callable,\n    num_machines: int = 1,\n    num_gpus_per_machine: int = 1,\n    machine_rank: int = 0,\n    dist_url: str = \"tcp://127.0.0.1:23456\",\n    args=(),\n):\n    r\"\"\"\n    Launch a job in a distributed fashion: given ``num_machines`` machines,\n    each with ``num_gpus_per_machine`` GPUs, this utility will launch one\n    process per GPU. This wrapper uses :func:`torch.multiprocessing.spawn`.\n\n    The user has to launch one job on each machine, manually specifying a\n    machine rank (incrementing integers from 0), this utility will adjust\n    process ranks per machine. One process on ``machine_rank = 0`` will be\n    refered as the *master process*, and the IP + a free port on this machine\n    will serve as the distributed process communication URL.\n\n    Default arguments imply one machine with one GPU, and communication URL\n    as ``localhost``.\n\n    .. note::\n\n        This utility assumes same number of GPUs per machine with IDs as\n        ``(0, 1, 2 ...)``. If you do not wish to use all GPUs on a machine,\n        set ``CUDA_VISIBLE_DEVICES`` environment variable (for example,\n        ``CUDA_VISIBLE_DEVICES=5,6``, which restricts to GPU 5 and 6 and\n        re-assigns their IDs to 0 and 1 in this job scope).\n\n    Args:\n        job_fn: A callable object to launch. Pass your main function doing\n            training, validation etc. 
here.\n        num_machines: Number of machines, each with ``num_gpus_per_machine`` GPUs.\n        num_gpus_per_machine: Number of GPUs per machine, with IDs as\n            ``(0, 1, 2 ...)``.\n        machine_rank: A manually specified rank of the machine, serves as a\n            unique identifier and is useful for assigning global ranks to processes.\n        dist_url: Distributed process communication URL as ``tcp://x.x.x.x:port``.\n            Set this as the IP (and a free port) of machine with rank 0.\n        args: Arguments to be passed to ``job_fn``.\n    \"\"\"\n\n    assert (\n        torch.cuda.is_available()\n    ), \"CUDA not available, cannot launch distributed processes.\"\n\n    world_size = num_machines * num_gpus_per_machine\n\n    # Spawn ``num_gpus_per_machine`` processes per machine, and provide\n    # \"local process rank\" (GPU ID) as the first arg to ``_job_worker``.\n    # fmt: off\n    if world_size > 1:\n        mp.spawn(\n            _job_worker,\n            nprocs=num_gpus_per_machine,\n            args=(\n                job_fn, world_size, num_gpus_per_machine, machine_rank, dist_url, args\n            ),\n            daemon=False,\n        )\n    else:\n        # Default to single machine, single GPU, with ID 0.\n        _job_worker(0, job_fn, 1, 1, 0, dist_url, args)\n    # fmt: on\n\n\ndef _job_worker(\n    local_rank: int,\n    job_fn: Callable,\n    world_size: int,\n    num_gpus_per_machine: int,\n    machine_rank: int,\n    dist_url: str,\n    args: Tuple,\n):\n    r\"\"\"\n    Single distributed process worker. 
This should never be used directly,\n    only used by :func:`launch`.\n    \"\"\"\n\n    # Adjust global rank of process based on its machine rank.\n    global_rank = machine_rank * num_gpus_per_machine + local_rank\n    try:\n        dist.init_process_group(\n            backend=\"NCCL\",\n            init_method=dist_url,\n            world_size=world_size,\n            rank=global_rank,\n        )\n    except Exception as e:\n        logger.error(f\"Error launching processes, dist URL: {dist_url}\")\n        raise e\n\n    synchronize()\n    # Set GPU ID for each process according to its rank.\n    torch.cuda.set_device(local_rank)\n    job_fn(*args)\n\n\ndef synchronize() -> None:\n    r\"\"\"Synchronize (barrier) all processes in a process group.\"\"\"\n    if dist.is_initialized():\n        dist.barrier()\n\n\ndef get_world_size() -> int:\n    r\"\"\"Return number of processes in the process group, each uses 1 GPU.\"\"\"\n    return dist.get_world_size() if dist.is_initialized() else 1\n\n\ndef get_rank() -> int:\n    r\"\"\"Return rank of current process in the process group.\"\"\"\n    return dist.get_rank() if dist.is_initialized() else 0\n\n\ndef is_master_process() -> bool:\n    r\"\"\"\n    Check whether current process is the master process. This check is useful\n    to restrict logging and checkpointing to master process. It will always\n    return ``True`` for single machine, single GPU execution.\n    \"\"\"\n    return get_rank() == 0\n\n\ndef average_across_processes(t: Union[torch.Tensor, Dict[str, torch.Tensor]]):\n    r\"\"\"\n    Averages a tensor, or a dict of tensors across all processes in a process\n    group. Objects in all processes will finally have same mean value.\n\n    .. 
note::\n\n        Nested dicts of tensors are not supported.\n\n    Args:\n        t: torch.Tensor or Dict[str, torch.Tensor]\n            A tensor or dict of tensors to average across processes.\n    \"\"\"\n    if dist.is_initialized():\n        if isinstance(t, torch.Tensor):\n            dist.all_reduce(t, op=dist.ReduceOp.SUM)\n            t /= get_world_size()\n        elif isinstance(t, dict):\n            for k in t:\n                dist.all_reduce(t[k], op=dist.ReduceOp.SUM)\n                t[k] /= dist.get_world_size()\n\n\ndef gpu_mem_usage() -> int:\n    r\"\"\"\n    Return gpu memory usage (in megabytes). If not using GPU, return 0 without\n    raising any exceptions.\n    \"\"\"\n    if torch.cuda.is_available():\n        # This will be in bytes, so we divide by (1024 * 1024).\n        return torch.cuda.max_memory_allocated() // 1048576\n    else:\n        return 0\n"
  },
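The semantics of `average_across_processes` (sum-reduce, then divide by world size, so every process ends up holding the mean) can be simulated in-memory. This sketch stands in for the `dist.all_reduce` collective with an in-memory list of per-process dicts; the simulation itself is an illustrative assumption, not part of the module:

```python
def average_across_processes(per_process_dicts):
    # Simulate the collective: sum each key across "processes", then
    # divide by world size. In the real utility this is
    # dist.all_reduce(t, op=dist.ReduceOp.SUM) followed by an in-place
    # divide, which leaves the same mean value on every process.
    world_size = len(per_process_dicts)
    return {
        key: sum(d[key] for d in per_process_dicts) / world_size
        for key in per_process_dicts[0]
    }

# Two "processes" report different losses; both would end up with the mean.
result = average_across_processes([{"loss": 2.0}, {"loss": 4.0}])
```

As in the real utility, nested dicts are not handled; only one flat level of keys is reduced.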
  {
    "path": "virtex/utils/metrics.py",
    "content": "r\"\"\"\nThis module is a collection of metrics commonly used during pretraining and\ndownstream evaluation. Two main classes here are:\n\n- :class:`TopkAccuracy` used for ImageNet linear classification evaluation.\n- :class:`CocoCaptionsEvaluator` used for caption evaluation (CIDEr and SPICE).\n\nParts of this module (:meth:`tokenize`, :meth:`cider` and :meth:`spice`) are\nadapted from `coco-captions evaluation code <https://github.com/tylin/coco-caption>`_.\n\"\"\"\nfrom collections import defaultdict\nimport json\nimport os\nfrom subprocess import Popen, PIPE, check_call\nimport tempfile\nfrom typing import Any, Dict, List\n\nimport numpy as np\nimport torch\n\n\nclass TopkAccuracy:\n    r\"\"\"\n    Top-K classification accuracy. This class can accumulate per-batch accuracy\n    that can be retrieved at the end of evaluation. Targets and predictions are\n    assumed to be integers (long tensors).\n\n    If used in :class:`~torch.nn.parallel.DistributedDataParallel`, results\n    need to be aggregated across GPU processes outside this class.\n\n    Args:\n        k: ``k`` for computing Top-K accuracy.\n    \"\"\"\n\n    def __init__(self, k: int = 1):\n        self._k = k\n        self.reset()\n\n    def reset(self):\n        self.num_total = 0.0\n        self.num_correct = 0.0\n\n    def __call__(self, predictions: torch.Tensor, ground_truth: torch.Tensor):\n        r\"\"\"\n        Record the accuracy of current batch of predictions and ground-truth.\n\n        Args:\n            predictions: Model predictions - logits or probabilities. Tensor of\n                shape ``(num_classes, )`` (not batched) or ``(B, num_classes)``.\n            ground_truth: Ground-truth integer labels. 
A scalar tensor or a batch\n                tensor of shape ``(B, )`` with values in ``[0, num_classes-1]``.\n\n        Returns:\n            Accuracy (in percentage) so far.\n        \"\"\"\n\n        # Get top-K predictions (based on scores).\n        if self._k == 1:\n            topk_preds = predictions.max(-1)[1].unsqueeze(-1)\n        else:\n            topk_preds = predictions.topk(min(self._k, predictions.shape[-1]), -1)[1]\n\n        correct = topk_preds.eq(ground_truth.unsqueeze(-1)).float()\n\n        self.num_total += ground_truth.numel()\n        self.num_correct += correct.sum()\n\n        return self.get_result()\n\n    def get_result(self):\n        # Prevent division by zero.\n        return self.num_correct / (self.num_total + 1e-12) * 100\n\n\nclass CocoCaptionsEvaluator:\n    r\"\"\"A helper class to evaluate caption predictions in COCO format. This uses\n    :meth:`cider` and :meth:`spice` which exactly follow original COCO Captions\n    evaluation protocol.\n\n    Args:\n        gt_annotations_path: Path to ground truth annotations in COCO format\n            (typically this would be COCO Captions ``val2017`` split).\n    \"\"\"\n\n    def __init__(self, gt_annotations_path: str):\n        gt_annotations = json.load(open(gt_annotations_path))[\"annotations\"]\n\n        # Keep a mapping from image id to a list of captions.\n        self.ground_truth: Dict[int, List[str]] = defaultdict(list)\n        for ann in gt_annotations:\n            self.ground_truth[ann[\"image_id\"]].append(ann[\"caption\"])\n\n        self.ground_truth = tokenize(self.ground_truth)\n\n    def evaluate(self, preds: List[Dict[str, Any]]) -> Dict[str, float]:\n        r\"\"\"Compute CIDEr and SPICE scores for predictions.\n\n        Args:\n            preds: List of per instance predictions in COCO Captions format:\n                ``[ {\"image_id\": int, \"caption\": str} ...]``.\n\n        Returns:\n            Computed metrics; a dict with keys ``{\"CIDEr\", 
\"SPICE\"}``.\n        \"\"\"\n        if isinstance(preds, str):\n            preds = json.load(open(preds))\n\n        res = {ann[\"image_id\"]: [ann[\"caption\"]] for ann in preds}\n        res = tokenize(res)\n\n        # Remove IDs from predictions which are not in GT.\n        common_image_ids = self.ground_truth.keys() & res.keys()\n        res = {k: v for k, v in res.items() if k in common_image_ids}\n\n        # Add dummy entries for IDs absent in preds, but present in GT.\n        for k in self.ground_truth:\n            res[k] = res.get(k, [\"\"])\n\n        cider_score = cider(res, self.ground_truth)\n        spice_score = spice(res, self.ground_truth)\n\n        return {\"CIDEr\": 100 * cider_score, \"SPICE\": 100 * spice_score}\n\n\ndef tokenize(image_id_to_captions: Dict[int, List[str]]) -> Dict[int, List[str]]:\n    r\"\"\"\n    Given a mapping of image id to a list of corrsponding captions, tokenize\n    captions in place according to Penn Treebank Tokenizer. This method assumes\n    the presence of Stanford CoreNLP JAR file in directory of this module.\n    \"\"\"\n    # Path to the Stanford CoreNLP JAR file.\n    CORENLP_JAR = (\n        \"assets/stanford-corenlp-full-2014-08-27/stanford-corenlp-3.4.1.jar\"\n    )\n\n    # Prepare data for Tokenizer: write captions to a text file, one per line.\n    image_ids = [k for k, v in image_id_to_captions.items() for _ in range(len(v))]\n    sentences = \"\\n\".join(\n        [c.replace(\"\\n\", \" \") for k, v in image_id_to_captions.items() for c in v]\n    )\n    tmp_file = tempfile.NamedTemporaryFile(delete=False)\n    tmp_file.write(sentences.encode())\n    tmp_file.close()\n\n    # fmt: off\n    # Tokenize sentences. 
We use the JAR file for tokenization.\n    command = [\n        \"java\", \"-cp\", CORENLP_JAR, \"edu.stanford.nlp.process.PTBTokenizer\",\n        \"-preserveLines\", \"-lowerCase\", tmp_file.name\n    ]\n    tokenized_captions = (\n        Popen(command, cwd=os.path.dirname(os.path.abspath(__file__)), stdout=PIPE)\n        .communicate(input=sentences.rstrip())[0]\n        .decode()\n        .split(\"\\n\")\n    )\n    # fmt: on\n    os.remove(tmp_file.name)\n\n    # Map tokenized captions back to their image IDs.\n    # Punctuation to be removed from the sentences (PTB style).\n    # fmt: off\n    PUNCTS = [\n        \"''\", \"'\", \"``\", \"`\", \"-LRB-\", \"-RRB-\", \"-LCB-\", \"-RCB-\", \".\", \"?\",\n        \"!\", \",\", \":\", \"-\", \"--\", \"...\", \";\",\n    ]\n    # fmt: on\n    image_id_to_tokenized_captions: Dict[int, List[str]] = defaultdict(list)\n    for image_id, caption in zip(image_ids, tokenized_captions):\n        image_id_to_tokenized_captions[image_id].append(\n            \" \".join([w for w in caption.rstrip().split(\" \") if w not in PUNCTS])\n        )\n\n    return image_id_to_tokenized_captions\n\n\ndef cider(\n    predictions: Dict[int, List[str]],\n    ground_truth: Dict[int, List[str]],\n    n: int = 4,\n    sigma: float = 6.0,\n) -> float:\n    r\"\"\"Compute CIDEr score given ground truth captions and predictions.\"\"\"\n\n    # -------------------------------------------------------------------------\n    def to_ngrams(sentence: str, n: int = 4):\n        r\"\"\"Convert a sentence into n-grams and their counts.\"\"\"\n        words = sentence.split()\n        counts = defaultdict(int)  # type: ignore\n        for k in range(1, n + 1):\n            for i in range(len(words) - k + 1):\n                ngram = tuple(words[i : i + k])\n                counts[ngram] += 1\n        return counts\n\n    def counts2vec(cnts, document_frequency, log_reference_length):\n        r\"\"\"Map n-gram counts to a vector of tfidf 
weights.\"\"\"\n        vec = [defaultdict(float) for _ in range(n)]\n        length = 0\n        norm = [0.0 for _ in range(n)]\n        for (ngram, term_freq) in cnts.items():\n            df = np.log(max(1.0, document_frequency[ngram]))\n            # tf (term_freq) * idf (precomputed idf) for n-grams\n            vec[len(ngram) - 1][ngram] = float(term_freq) * (\n                log_reference_length - df\n            )\n            # Compute norm for the vector: will be used for computing similarity\n            norm[len(ngram) - 1] += pow(vec[len(ngram) - 1][ngram], 2)\n\n            if len(ngram) == 2:\n                length += term_freq\n        norm = [np.sqrt(nn) for nn in norm]\n        return vec, norm, length\n\n    def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):\n        r\"\"\"Compute the cosine similarity of two vectors.\"\"\"\n        delta = float(length_hyp - length_ref)\n        val = np.array([0.0 for _ in range(n)])\n        for nn in range(n):\n            for (ngram, count) in vec_hyp[nn].items():\n                val[nn] += (\n                    min(vec_hyp[nn][ngram], vec_ref[nn][ngram]) * vec_ref[nn][ngram]\n                )\n\n            val[nn] /= (norm_hyp[nn] * norm_ref[nn]) or 1\n            val[nn] *= np.e ** (-(delta ** 2) / (2 * sigma ** 2))\n        return val\n\n    # -------------------------------------------------------------------------\n\n    ctest = [to_ngrams(predictions[image_id][0]) for image_id in ground_truth]\n    crefs = [\n        [to_ngrams(gt) for gt in ground_truth[image_id]] for image_id in ground_truth\n    ]\n    # Build document frequency and compute IDF.\n    document_frequency = defaultdict(float)\n    for refs in crefs:\n        # refs, k ref captions of one image\n        for ngram in set([ngram for ref in refs for (ngram, count) in ref.items()]):\n            document_frequency[ngram] += 1\n\n    # Compute log reference length.\n    log_reference_length = 
np.log(float(len(crefs)))\n\n    scores = []\n    for test, refs in zip(ctest, crefs):\n        # Compute vector for test captions.\n        vec, norm, length = counts2vec(\n            test, document_frequency, log_reference_length\n        )\n        # Compute vector for ref captions.\n        score = np.array([0.0 for _ in range(n)])\n        for ref in refs:\n            vec_ref, norm_ref, length_ref = counts2vec(\n                ref, document_frequency, log_reference_length\n            )\n            score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)\n\n        score_avg = np.mean(score)\n        score_avg /= len(refs)\n        score_avg *= 10.0\n        scores.append(score_avg)\n\n    return np.mean(scores)\n\n\ndef spice(\n    predictions: Dict[int, List[str]], ground_truth: Dict[int, List[str]]\n) -> float:\n    r\"\"\"Compute SPICE score given ground truth captions and predictions.\"\"\"\n\n    # Prepare temporary input file for the SPICE scorer.\n    input_data = [\n        {\n            \"image_id\": image_id,\n            \"test\": predictions[image_id][0],\n            \"refs\": ground_truth[image_id],\n        }\n        for image_id in ground_truth\n    ]\n    # Create a temporary directory and dump input file to SPICE.\n    temp_dir = tempfile.mkdtemp()\n    INPUT_PATH = os.path.join(temp_dir, \"input_file.json\")\n    OUTPUT_PATH = os.path.join(temp_dir, \"output_file.json\")\n    json.dump(input_data, open(INPUT_PATH, \"w\"))\n\n    # fmt: off\n    # Run the command to execute SPICE jar.\n    CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))\n    SPICE_JAR = f\"{CURRENT_DIR}/assets/SPICE-1.0/spice-1.0.jar\"\n    CACHE_DIR = f\"{CURRENT_DIR}/assets/cache\"\n    os.makedirs(CACHE_DIR, exist_ok=True)\n    spice_cmd = [\n        \"java\", \"-jar\", \"-Xmx8G\", SPICE_JAR, INPUT_PATH,\n        \"-cache\", CACHE_DIR, \"-out\", OUTPUT_PATH, \"-subset\", \"-silent\",\n    ]\n    check_call(spice_cmd, cwd=CURRENT_DIR)\n    # fmt: 
on\n\n    # Read and process results.\n    results = json.load(open(OUTPUT_PATH))\n    spice_scores = [\n        np.array(item[\"scores\"][\"All\"][\"f\"]).astype(float) for item in results\n    ]\n    return np.mean(spice_scores)\n"
  },
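The inner `to_ngrams` helper of `cider` is easy to exercise standalone. The sketch below reproduces its 1..n gram counting; `n=2` here is just to keep the example small (the CIDEr default is `n=4`):

```python
from collections import defaultdict

def to_ngrams(sentence, n=4):
    # Count every 1..n gram of a whitespace-tokenized sentence,
    # exactly as the helper inside cider() does.
    words = sentence.split()
    counts = defaultdict(int)
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return counts

counts = to_ngrams("a cat sat", n=2)
# 3 unigrams + 2 bigrams, each occurring once.
```

These counts are what `counts2vec` then weights by tf-idf before the per-n cosine similarity.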
  {
    "path": "virtex/utils/nucleus_sampling.py",
    "content": "r\"\"\"\nNucleus Sampling was introduced in the paper\n`The Curious Case of Neural Text Degeneration <https://arxiv.org/abs/1904.09751>`_.\nIf you take it from here, make sure to cite them:\n\n.. code-block:: text\n\n    @inproceedings{,\n        title={The Curious Case of Neural Text Degeneration},\n        author={Ari Holtzman and Jan Buys and Li Du and Maxwell Forbes and Yejin Choi},\n        journal={ICLR},\n        year={2020}\n    }\n\nSome core parts of this code are adapted with minor modifications from Thomas Wolf's\ngist: https://gist.githubusercontent.com/thomwolf/1a5a29f6962089e871b94cbd09daf317\n\"\"\"\n\nfrom typing import Callable, List, Tuple\n\nimport torch\nimport torch.nn.functional as F\n\n\nclass AutoRegressiveNucleusSampling:\n    r\"\"\"\n    Implements the nucleus sampling for decoding captions. This class only works\n    for auto-regressive models (Transformer-like), not recurrent models (LSTM-like).\n\n    Args:\n        eos_index: The index of the end token (``[EOS]``) in vocabulary.\n        max_steps: The maximum number of decoding steps.\n        nucleus_size: Size of top-K nucleus for sampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        eos_index: int,\n        max_steps: int = 50,\n        nucleus_size: float = 0.9,\n    ):\n        super().__init__()\n        self._eos_index = eos_index\n        self.max_steps = max_steps\n        self.nucleus_size = nucleus_size\n\n    def search(\n        self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor]\n    ) -> Tuple[torch.Tensor, None]:\n\n        batch_size = start_predictions.size()[0]\n\n        # List of `(batch_size, )` tensors. One for each timestep.\n        # This includes the start-of-sentence tokens, unlike the implementation\n        # in `AutoregressiveBeamSearch`. 
We will remove them in the end.\n        predictions: List[torch.Tensor] = [start_predictions]\n\n        for timestep in range(self.max_steps):\n            # Get the predictions from last timestep (most recent).\n            # shape: (batch_size, )\n            last_predictions = predictions[-1]\n\n            # If every predicted token from the last step is end-of-sentence token,\n            # then we can stop early.\n            if (last_predictions == self._eos_index).all():\n                break\n\n            # Combine step predictions made so far into one tensor. This is our\n            # \"partial\" caption input to the transformer.\n            # shape: (batch_size, timestep + 1)\n            predictions_so_far = torch.stack(predictions).permute(1, 0)\n\n            # Take a step, get the distribution of logits from next timestep.\n            # shape: (batch_size, num_classes)\n            current_logits = step(predictions_so_far)\n\n            # Sort logits in descending order to determine the nucleus.\n            sorted_logits, sorted_idx = torch.sort(current_logits, descending=True)\n\n            # Get cumulative softmax probabilities. For every instance in batch, a\n            # variable number of tokens (N) will constitute the nucleus.\n            # shape: (batch_size, num_classes)\n            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n\n            # Determine indices of tokens at the tail of distribution. These will be\n            # removed from the nucleus.\n            sorted_idx_to_remove = cumulative_probs > self.nucleus_size\n\n            # Shift the mask to the right to keep the first above-threshold token\n            # inside the nucleus.\n            sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()\n            sorted_idx_to_remove[..., 0] = 0\n\n            # Set logits to a large negative value to avoid sampling them. 
Iterate over\n            # the batch of examples.\n            for t in range(current_logits.size()[0]):\n                idx_to_remove = sorted_idx[t][sorted_idx_to_remove[t]]\n                current_logits[t][idx_to_remove] = -1e12\n\n                # Set the logit of the last predicted token to a large negative\n                # value to avoid repetition.\n                current_logits[t][last_predictions[t]] = -1e12\n\n            # Sample from the filtered distribution.\n            # shape: (batch_size, num_classes)\n            current_probs = F.softmax(current_logits, dim=-1)\n\n            # shape: (batch_size, )\n            current_predictions = torch.multinomial(current_probs, 1)\n            current_predictions = current_predictions.view(batch_size)\n\n            # Set current predicted tokens to the end-of-sentence token for\n            # instances where the last prediction was also the end-of-sentence token.\n            current_predictions[last_predictions == self._eos_index] = self._eos_index\n\n            predictions.append(current_predictions)\n\n        # Remove the start-of-sentence token from predictions and collect them\n        # together.\n        # shape: (batch_size, max_steps), or fewer steps if decoding stopped early.\n        all_predictions = torch.stack(predictions[1:]).permute(1, 0)\n\n        # Unlike `AutoregressiveBeamSearch`, we do not return log-probs of the\n        # generated sequence with nucleus sampling.\n        return all_predictions, None\n"
  },
  {
    "path": "virtex/utils/timer.py",
    "content": "import time\nfrom typing import Optional\n\n\nclass Timer:\n    r\"\"\"\n    A simple timer to record time per iteration and ETA of training. ETA is\n    estimated by moving window average with fixed window size.\n\n    Args:\n        start_from: Iteration from which counting should be started/resumed.\n        total_iterations: Total number of iterations. ETA will not be tracked\n            (will remain \"N/A\") if this is not provided.\n        window_size: Window size to calculate ETA based on past few iterations.\n    \"\"\"\n\n    def __init__(\n        self,\n        start_from: int = 1,\n        total_iterations: Optional[int] = None,\n        window_size: int = 20,\n    ):\n        # We decrement by 1 because `current_iter` changes increment during\n        # an iteration (for example, will change from 0 -> 1 on iteration 1).\n        self.current_iter = start_from - 1\n        self.total_iters = total_iterations\n\n        self._start_time = time.time()\n        self._times = [0.0] * window_size\n\n    def tic(self) -> None:\n        r\"\"\"Start recording time: call at the beginning of iteration.\"\"\"\n        self._start_time = time.time()\n\n    def toc(self) -> None:\n        r\"\"\"Stop recording time: call at the end of iteration.\"\"\"\n        self._times.append(time.time() - self._start_time)\n        self._times = self._times[1:]\n        self.current_iter += 1\n\n    @property\n    def stats(self) -> str:\n        r\"\"\"Return a single string with current iteration, time and ETA.\"\"\"\n        return (\n            f\"Iter {self.current_iter} | Time: {self._times[-1]:.3f} sec | \"\n            f\"ETA: {self.eta_hhmm}\"\n        )\n\n    @property\n    def eta_hhmm(self) -> str:\n        r\"\"\"Return ETA in the form of ``hh mm`` string.\"\"\"\n        if self.total_iters:\n            avg_time = sum(self._times) / len(self._times)\n            eta_sec = int(avg_time * (self.total_iters - self.current_iter))\n            
return f\"{eta_sec // 3600}h {((eta_sec % 3600) // 60):02d}m\"\n        else:\n            return \"N/A\"\n"
  }
]