Repository: kdexd/virtex Branch: master Commit: ae67b23f86ab Files: 99 Total size: 274.8 KB Directory structure: gitextract_cto194sv/ ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── configs/ │ ├── _base_bicaptioning_R_50_L1_H1024.yaml │ ├── backbone_ablations/ │ │ ├── bicaptioning_R_101_L1_H1024.yaml │ │ ├── bicaptioning_R_50W2X_L1_H1024.yaml │ │ └── bicaptioning_R_50_L1_H1024.yaml │ ├── depth_ablations/ │ │ ├── bicaptioning_R_50_L1_H1024.yaml │ │ ├── bicaptioning_R_50_L2_H1024.yaml │ │ ├── bicaptioning_R_50_L3_H1024.yaml │ │ └── bicaptioning_R_50_L4_H1024.yaml │ ├── detectron2/ │ │ ├── _base_faster_rcnn_R_50_C4_BN.yaml │ │ ├── _base_mask_rcnn_R_50_FPN.yaml │ │ ├── coco_segm_default_init_2x.yaml │ │ ├── lvis_segm_default_init_2x.yaml │ │ ├── lvis_segm_imagenet_init_2x.yaml │ │ └── voc_det_default_init_24k.yaml │ ├── downstream/ │ │ ├── imagenet_clf.yaml │ │ ├── inaturalist_clf.yaml │ │ └── voc07_clf.yaml │ ├── task_ablations/ │ │ ├── bicaptioning_R_50_L1_H2048.yaml │ │ ├── captioning_R_50_L1_H2048.yaml │ │ ├── masked_lm_R_50_L1_H2048.yaml │ │ ├── multilabel_classification_R_50.yaml │ │ └── token_classification_R_50.yaml │ └── width_ablations/ │ ├── bicaptioning_R_50_L1_H1024.yaml │ ├── bicaptioning_R_50_L1_H2048.yaml │ ├── bicaptioning_R_50_L1_H512.yaml │ └── bicaptioning_R_50_L1_H768.yaml ├── docs/ │ ├── Makefile │ ├── _templates/ │ │ └── layout.html │ ├── conf.py │ ├── index.rst │ └── virtex/ │ ├── config.rst │ ├── data.datasets.rst │ ├── data.rst │ ├── data.tokenizers.rst │ ├── data.transforms.rst │ ├── factories.rst │ ├── model_zoo.rst │ ├── models.rst │ ├── modules.embedding.rst │ ├── modules.rst │ ├── modules.textual_heads.rst │ ├── modules.visual_backbones.rst │ ├── optim.lookahead.rst │ ├── optim.lr_scheduler.rst │ ├── optim.rst │ ├── usage/ │ │ ├── downstream.rst │ │ ├── model_zoo.rst │ │ ├── pretrain.rst │ │ └── setup_dependencies.rst │ ├── utils.beam_search.rst │ ├── utils.checkpointing.rst │ ├── utils.common.rst │ ├── utils.distributed.rst 
│ ├── utils.metrics.rst │ ├── utils.rst │ └── utils.timer.rst ├── hubconf.py ├── requirements.txt ├── scripts/ │ ├── build_vocabulary.py │ ├── clf_linear.py │ ├── clf_voc07.py │ ├── eval_captioning.py │ ├── eval_detectron2.py │ └── pretrain_virtex.py ├── setup.py └── virtex/ ├── __init__.py ├── config.py ├── data/ │ ├── __init__.py │ ├── datasets/ │ │ ├── captioning.py │ │ ├── classification.py │ │ ├── coco_captions.py │ │ ├── downstream.py │ │ └── masked_lm.py │ ├── tokenizers.py │ └── transforms.py ├── factories.py ├── model_zoo/ │ ├── __init__.py │ └── model_zoo.py ├── models/ │ ├── __init__.py │ ├── captioning.py │ ├── classification.py │ └── masked_lm.py ├── modules/ │ ├── embedding.py │ ├── textual_heads.py │ └── visual_backbones.py ├── optim/ │ ├── __init__.py │ ├── lookahead.py │ └── lr_scheduler.py └── utils/ ├── beam_search.py ├── checkpointing.py ├── common.py ├── distributed.py ├── metrics.py ├── nucleus_sampling.py └── timer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # Code Editors .vscode .idea # Code linters .mypy_cache # Datasets and preprocessed files data/ !virtex/data # IPython Notebook .ipynb_checkpoints # virtualenv venv/ ENV/ # Temporary scripts to (smoke) test out bits and pieces of code. scripts/test_* # Data (symlinks) directory, model checkpoints, tensorboard logs etc. 
datasets/ checkpoints/ virtex/utils/assets/ !virtex/data/datasets/ virtex/model_zoo/configs ================================================ FILE: CHANGELOG.md ================================================ CHANGELOG ========= This CHANGELOG file records changes between different arXiv versions of our paper, and the version of this codebase which should be used to reproduce the results in the corresponding arXiv version. View changes between code versions on the [Releases page](https://github.com/kdexd/virtex/releases). ArXiv v2 -> v3 ============== **Code version:** `v1.2`. Fix image captioning results with a modified beam search implementation. _Rest of the downstream task results and pre-trained models are unchanged._ ArXiv v1 -> v2 ============== **Code version:** `v1.0` or `v1.1`. [ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0). While the core motivation and approach is the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performances across the board (within decimals). New models are available in [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html), however links to old models in `v0.9` will be active till June 30, 2021. We encourage you to use the new models! We have updated the experiment config files for all changes described below. Experiment Changes ------------------ ### New Feature: Add a new pretraining task for BERT-style _Masked Language Modeling_. Pre-trained model released in Model Zoo. ### Pre-training: - The only change during pre-training is that we do not apply weight decay to LayerNorm and biases in input embedding and transformer layers. 
We apply weight decay to the biases in output linear layer (before softmax). - Other factors that could affect results: - Use official [albumentations.ColorJitter transform](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter) that mimics torchvision ColorJitter transform. Earlier I implemented [my own ColorJitter](https://github.com/kdexd/virtex/blob/c19e7fc9b98e98af82286ed1537b6f588eaeac44/virtex/data/transforms.py#L156) because albumentations didn't have one. - Use PyTorch Native AMP (Automatic Mixed Precision) instead of NVIDIA Apex. ### Downstream Evaluations: 1. **PASCAL VOC 2007 Linear Classification:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-b4405dd4879a48ef1e5b1e2801035909584a5f1f32f63d5e793fb50dee077b97) - Instead of training linear SVMs on 8192-dimensional average pooled features from ResNet-50 (7x7x2048 —> 2x2x2048), like [(Misra et al. 2019)](https://arxiv.org/abs/1905.01235), we directly train SVMs on 2048-dimensional global average pooled features, following recent works like [SwAV (Caron et al. 2020)](https://arxiv.org/abs/2006.09882). - We change the pre-processing: resize shortest edge to 256 pixels, and take center crop of 224 pixels. - These improve VOC mAP by 1-2 points everywhere, and makes SVM training faster. Since we select best checkpoint based on this metric, all results on other downstream tasks also change in `ArXiv v2` (But the trends remain same.) 2. **ImageNet Linear Evaluation:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-d3dea1e7bf97d0cfca4b59a47c0a9bb81e78b8827654fe0258df9ce2c3f5f41c) - Changed random resized crop scale from (20-100%) to (8-100%) for consistency with evaluations in SSL works like MoCo and SwAV. - Use cosine LR decay instead of step decay, following SwAV. 
Improves accuracy by up to 1%. 3. **iNaturalist Fine-tuning:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-09096da78cfcde3a604ce22d80313f0800225d928cce5ef7334b89a382adfe4d) - This evaluation is left unchanged across ArXiv versions, but we fixed a typo in image pre-processing step, present in publicly released config. 4. **Detectron2 tasks (COCO and LVIS Instance Segmentation, VOC Detection):** - Heavily simplified the script. Updated Detectron2 uses a more memory-efficient SyncBatchNorm and supports AMP. ================================================ FILE: LICENSE ================================================ Copyright (c) 2020, Karan Desai. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: README.md ================================================ VirTex: Learning Visual Representations from Textual Annotations ================================================================

Karan Desai and Justin Johnson
University of Michigan


**CVPR 2021** [arxiv.org/abs/2006.06666][1] **Model Zoo, Usage Instructions and API docs:** [kdexd.github.io/virtex](https://kdexd.github.io/virtex) VirTex is a pretraining approach which uses semantically dense captions to learn visual representations. We train CNN + Transformers from scratch on COCO Captions, and transfer the CNN to downstream vision tasks including image classification, object detection, and instance segmentation. VirTex matches or outperforms models which use ImageNet for pretraining -- both supervised or unsupervised -- despite using up to 10x fewer images. ![virtex-model](docs/_static/system_figure.jpg) Get the pretrained ResNet-50 visual backbone from our best performing VirTex model in one line *without any installation*! ```python import torch # That's it, this one line only requires PyTorch. model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True) ``` ### Note (For returning users before January 2021): The pretrained models in our model zoo have changed from [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0) onwards. They are slightly better tuned than older models, and reproduce the results in our CVPR 2021 accepted paper ([arXiv v2](https://arxiv.org/abs/2006.06666v2)). Some training and evaluation hyperparams are changed since [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9). Please refer to [`CHANGELOG.md`](https://github.com/kdexd/virtex/blob/master/CHANGELOG.md) for details. Usage Instructions ------------------ 1. [How to setup this codebase?][2] 2. [VirTex Model Zoo][3] 3. [How to train your VirTex model?][4] 4. [How to evaluate on downstream tasks?][5] Full documentation is available at [kdexd.github.io/virtex](https://kdexd.github.io/virtex). 
Citation -------- If you find this code useful, please consider citing: ```text @inproceedings{desai2021virtex, title={{VirTex: Learning Visual Representations from Textual Annotations}}, author={Karan Desai and Justin Johnson}, booktitle={CVPR}, year={2021} } ``` Acknowledgments --------------- We thank Harsh Agrawal, Mohamed El Banani, Richard Higgins, Nilesh Kulkarni and Chris Rockwell for helpful discussions and feedback on the paper. We thank Ishan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for discussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and Yuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting the Instance Segmentation pretraining task ablation; and Stefan Lee for suggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs during project development; and UMich ARC-TS team for support with GPU cluster management. Finally, we thank all the Starbucks outlets in Ann Arbor for many hours of free WiFi. This work was partially supported by the Toyota Research Institute (TRI). However, note that this article solely reflects the opinions and conclusions of its authors and not TRI or any other Toyota entity. [1]: https://arxiv.org/abs/2006.06666 [2]: https://kdexd.github.io/virtex/virtex/usage/setup_dependencies.html [3]: https://kdexd.github.io/virtex/virtex/usage/model_zoo.html [4]: https://kdexd.github.io/virtex/virtex/usage/pretrain.html [5]: https://kdexd.github.io/virtex/virtex/usage/downstream.html ================================================ FILE: configs/_base_bicaptioning_R_50_L1_H1024.yaml ================================================ # ----------------------------------------------------------------------------- # Base config: VirTex pretraining for our "base" bicaptioning model: # ResNet-50 + (L = 1, H = 1024) transformer trained for 500K iterations. 
# ----------------------------------------------------------------------------- RANDOM_SEED: 0 AMP: true CUDNN_BENCHMARK: true CUDNN_DETERMINISTIC: false DATA: ROOT: "datasets/coco" TOKENIZER_MODEL: "datasets/vocab/coco_10k.model" VOCAB_SIZE: 10000 UNK_INDEX: 0 SOS_INDEX: 1 EOS_INDEX: 2 MASK_INDEX: 3 IMAGE_CROP_SIZE: 224 MAX_CAPTION_LENGTH: 30 IMAGE_TRANSFORM_TRAIN: - "random_resized_crop" - "horizontal_flip" - "color_jitter" - "normalize" IMAGE_TRANSFORM_VAL: - "smallest_resize" - "center_crop" - "normalize" MODEL: NAME: "virtex" VISUAL: NAME: "torchvision::resnet50" PRETRAINED: false FROZEN: false TEXTUAL: NAME: "transdec_postnorm::L1_H1024_A16_F4096" DROPOUT: 0.1 DECODER: NAME: "beam_search" BEAM_SIZE: 5 OPTIM: OPTIMIZER_NAME: "sgd" SGD_MOMENTUM: 0.9 WEIGHT_DECAY: 0.0001 LOOKAHEAD: USE: true ALPHA: 0.5 STEPS: 5 BATCH_SIZE: 256 CNN_LR: 0.2 LR: 0.001 NUM_ITERATIONS: 500000 WARMUP_STEPS: 10000 LR_DECAY_NAME: "cosine" NO_DECAY: ".*textual.(embedding|transformer).*(norm.*|bias)" CLIP_GRAD_NORM: 10.0 ================================================ FILE: configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: VISUAL: NAME: "torchvision::resnet101" ================================================ FILE: configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: VISUAL: NAME: "torchvision::wide_resnet50_2" ================================================ FILE: configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" ================================================ FILE: configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" 
================================================ FILE: configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: TEXTUAL: NAME: "transdec_postnorm::L2_H1024_A16_F4096" ================================================ FILE: configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: TEXTUAL: NAME: "transdec_postnorm::L3_H1024_A16_F4096" ================================================ FILE: configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: TEXTUAL: NAME: "transdec_postnorm::L4_H1024_A16_F4096" ================================================ FILE: configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml ================================================ # ---------------------------------------------------------------------------- # Train a Faster R-CNN with ResNet-50 and C4 backbone. This config follows # Detectron2 format; and is unrelated with our VirTex configs. Params here # replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722). # ---------------------------------------------------------------------------- INPUT: # Input format will always be RGB, consistent with torchvision. FORMAT: "RGB" MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) MIN_SIZE_TEST: 800 MODEL: META_ARCHITECTURE: "GeneralizedRCNN" # Train all layers end-to-end by default. BACKBONE: NAME: build_resnet_backbone FREEZE_AT: 0 # Fine-tune with SyncBN. # STRIDE_IN_1X1 is False for torchvision-like models. RESNETS: DEPTH: 50 NORM: SyncBN STRIDE_IN_1X1: False RPN: PRE_NMS_TOPK_TEST: 6000 POST_NMS_TOPK_TEST: 1000 # ROI head with extra BN layer after res5 stage. ROI_HEADS: NAME: "Res5ROIHeadsExtraNorm" # ImageNet color mean for torchvision-like models (RGB order). 
PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] SOLVER: # This is for 8 GPUs, apply linear scaling for 4 GPUs. IMS_PER_BATCH: 16 BASE_LR: 0.02 TEST: PRECISE_BN: ENABLED: True VERSION: 2 ================================================ FILE: configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml ================================================ # ---------------------------------------------------------------------------- # Train a Mask R-CNN with ResNet-50 and FPN backbone. This config follows # Detectron2 format; and is unrelated with our VirTex configs. Params here # replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722). # ---------------------------------------------------------------------------- INPUT: # Input format will always be RGB, consistent with torchvision. FORMAT: "RGB" MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) MIN_SIZE_TEST: 800 MODEL: META_ARCHITECTURE: "GeneralizedRCNN" # Train all layers end-to-end by default. BACKBONE: NAME: "build_resnet_fpn_backbone" FREEZE_AT: 0 # Fine-tune with SyncBN. # STRIDE_IN_1X1 is False for torchvision-like models. RESNETS: DEPTH: 50 NORM: "SyncBN" STRIDE_IN_1X1: False OUT_FEATURES: ["res2", "res3", "res4", "res5"] FPN: IN_FEATURES: ["res2", "res3", "res4", "res5"] ANCHOR_GENERATOR: # One size for each in feature map SIZES: [[32], [64], [128], [256], [512]] # Three aspect ratios (same for all in feature maps) ASPECT_RATIOS: [[0.5, 1.0, 2.0]] RPN: IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] PRE_NMS_TOPK_TRAIN: 2000 PRE_NMS_TOPK_TEST: 1000 POST_NMS_TOPK_TRAIN: 1000 POST_NMS_TOPK_TEST: 1000 ROI_HEADS: NAME: "StandardROIHeads" IN_FEATURES: ["p2", "p3", "p4", "p5"] ROI_BOX_HEAD: NAME: "FastRCNNConvFCHead" NUM_FC: 2 POOLER_RESOLUTION: 7 ROI_MASK_HEAD: NAME: "MaskRCNNConvUpsampleHead" NUM_CONV: 4 POOLER_RESOLUTION: 14 # ImageNet color mean for torchvision-like models (RGB order). # These are in [0-255] range as expected by Detectron2. 
Rest of our codebase # uses [0-1] range; but both are equivalent and consistent. PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] SOLVER: # This is for 8 GPUs, apply linear scaling for 4 GPUs. IMS_PER_BATCH: 16 BASE_LR: 0.02 TEST: PRECISE_BN: ENABLED: True VERSION: 2 ================================================ FILE: configs/detectron2/coco_segm_default_init_2x.yaml ================================================ # ----------------------------------------------------------------------------- # Train a Mask R-CNN R50-FPN backbone on COCO instance segmentation with any of # these weight init: random, imagenet (torchvision), virtex or MoCo. # ----------------------------------------------------------------------------- _BASE_: "_base_mask_rcnn_R_50_FPN.yaml" DATASETS: TRAIN: ("coco_2017_train",) TEST: ("coco_2017_val",) MODEL: MASK_ON: True # FPN also has SyncBN, as opposed to no norm (usually). FPN: NORM: "SyncBN" # This will be ignored, weights will be loaded manually in the script. WEIGHTS: "" SOLVER: STEPS: (120000, 160000) MAX_ITER: 180000 VERSION: 2 ================================================ FILE: configs/detectron2/lvis_segm_default_init_2x.yaml ================================================ # ----------------------------------------------------------------------------- # Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation with any of # these weight init: random, virtex or MoCo. (ImageNet init config is separate) # ----------------------------------------------------------------------------- _BASE_: "_base_mask_rcnn_R_50_FPN.yaml" DATASETS: TRAIN: ("lvis_v1_train",) TEST: ("lvis_v1_val",) DATALOADER: SAMPLER_TRAIN: "RepeatFactorTrainingSampler" REPEAT_THRESHOLD: 0.001 TEST: DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300. MODEL: MASK_ON: True # FPN also has SyncBN, as opposed to no norm (usually). 
FPN: NORM: "SyncBN" ROI_HEADS: NUM_CLASSES: 1203 SCORE_THRESH_TEST: 0.0001 # This will be ignored, weights will be loaded manually in the script. WEIGHTS: "" SOLVER: STEPS: (120000, 160000) MAX_ITER: 180000 VERSION: 2 ================================================ FILE: configs/detectron2/lvis_segm_imagenet_init_2x.yaml ================================================ # ----------------------------------------------------------------------------- # Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation # with weights initialized from supervised ImageNet pretraining (torchvision). # Key difference is that fine-tuning here happens with BN frozen. # ----------------------------------------------------------------------------- _BASE_: "_base_mask_rcnn_R_50_FPN.yaml" DATASETS: TRAIN: ("lvis_v1_train",) TEST: ("lvis_v1_val",) DATALOADER: SAMPLER_TRAIN: "RepeatFactorTrainingSampler" REPEAT_THRESHOLD: 0.001 TEST: DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300. MODEL: MASK_ON: True RESNETS: NORM: "FrozenBN" # Do not tune with SyncBN for ImageNet init from LVIS. ROI_HEADS: NUM_CLASSES: 1203 SCORE_THRESH_TEST: 0.0001 # This will be ignored, weights will be loaded manually in the script. WEIGHTS: "" SOLVER: STEPS: (120000, 160000) MAX_ITER: 180000 VERSION: 2 ================================================ FILE: configs/detectron2/voc_det_default_init_24k.yaml ================================================ # ----------------------------------------------------------------------------- # Train a Faster R-CNN with R50-C4 backbone on VOC07+12 detection with any of # these weight init: random, imagenet (torchvision), virtex or MoCo. 
# ----------------------------------------------------------------------------- _BASE_: "_base_faster_rcnn_R_50_C4_BN.yaml" DATASETS: TRAIN: ("voc_2007_trainval", "voc_2012_trainval") TEST: ("voc_2007_test",) INPUT: MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) MIN_SIZE_TEST: 800 MODEL: MASK_ON: False ROI_HEADS: NUM_CLASSES: 20 # This will be ignored, weights will be loaded manually in the script. WEIGHTS: "" SOLVER: STEPS: (18000, 22000) MAX_ITER: 24000 WARMUP_ITERS: 100 VERSION: 2 ================================================ FILE: configs/downstream/imagenet_clf.yaml ================================================ RANDOM_SEED: 0 # Don't need AMP to train a tiny linear layer. AMP: false CUDNN_BENCHMARK: true CUDNN_DETERMINISTIC: false DATA: ROOT: "datasets/imagenet" IMAGE_TRANSFORM_TRAIN: - "random_resized_crop::{'scale': (0.08, 1.0)}" - "horizontal_flip" - "normalize" IMAGE_TRANSFORM_VAL: - "smallest_resize" - "center_crop" - "normalize" MODEL: VISUAL: FROZEN: true OPTIM: BATCH_SIZE: 256 SGD_MOMENTUM: 0.9 WEIGHT_DECAY: 0.0 NO_DECAY: "none" LOOKAHEAD: USE: false LR: 0.3 WARMUP_STEPS: 0 LR_DECAY_NAME: "cosine" NUM_ITERATIONS: 500500 # 100 epochs ================================================ FILE: configs/downstream/inaturalist_clf.yaml ================================================ RANDOM_SEED: 0 AMP: true CUDNN_BENCHMARK: true CUDNN_DETERMINISTIC: false DATA: ROOT: "datasets/inaturalist" IMAGE_TRANSFORM_TRAIN: - "random_resized_crop::{'scale': (0.08, 1.0)}" - "horizontal_flip" - "normalize" IMAGE_TRANSFORM_VAL: - "smallest_resize" - "center_crop" - "normalize" MODEL: VISUAL: FROZEN: false OPTIM: BATCH_SIZE: 256 SGD_MOMENTUM: 0.9 WEIGHT_DECAY: 0.0001 NO_DECAY: "none" LOOKAHEAD: USE: false LR: 0.025 WARMUP_STEPS: 0 LR_DECAY_NAME: multistep LR_GAMMA: 0.1 LR_STEPS: - 119700 # 70 epochs - 153900 # 90 epochs NUM_ITERATIONS: 171000 # 100 epochs ================================================ FILE: configs/downstream/voc07_clf.yaml 
================================================ RANDOM_SEED: 0 DATA: ROOT: datasets/VOC2007 IMAGE_TRANSFORM_TRAIN: - smallest_resize - center_crop - normalize IMAGE_TRANSFORM_VAL: - smallest_resize - center_crop - normalize OPTIM: # Only used for feature extraction, doesn't mean much. BATCH_SIZE: 128 ================================================ FILE: configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: TEXTUAL: NAME: "transdec_postnorm::L1_H2048_A32_F8192" ================================================ FILE: configs/task_ablations/captioning_R_50_L1_H2048.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: NAME: "captioning" TEXTUAL: NAME: "transdec_postnorm::L1_H2048_A32_F8192" ================================================ FILE: configs/task_ablations/masked_lm_R_50_L1_H2048.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: NAME: "masked_lm" TEXTUAL: NAME: "transdec_postnorm::L1_H2048_A32_F8192" ================================================ FILE: configs/task_ablations/multilabel_classification_R_50.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" DATA: VOCAB_SIZE: 81 MODEL: NAME: "multilabel_classification" TEXTUAL: NAME: "none" OPTIM: NO_DECAY: "none" ================================================ FILE: configs/task_ablations/token_classification_R_50.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: NAME: "token_classification" TEXTUAL: NAME: "none" OPTIM: NO_DECAY: "none" ================================================ FILE: configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" 
================================================ FILE: configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: TEXTUAL: NAME: "transdec_postnorm::L1_H2048_A32_F8192" ================================================ FILE: configs/width_ablations/bicaptioning_R_50_L1_H512.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: TEXTUAL: NAME: "transdec_postnorm::L1_H512_A8_F2048" ================================================ FILE: configs/width_ablations/bicaptioning_R_50_L1_H768.yaml ================================================ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml" MODEL: TEXTUAL: NAME: "transdec_postnorm::L1_H768_A12_F3072" ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . BUILDDIR = ../../virtex-sphinx # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/_templates/layout.html ================================================ {% extends "!layout.html" %} {% block htmltitle %} {{ super() }} {% endblock %} ================================================ FILE: docs/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. 
# For a full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import inspect
import os
import sys

# Put the parent of this docs directory on the import path so autodoc can
# import the package being documented.
sys.path.insert(0, os.path.abspath("../"))


# -- Project information -----------------------------------------------------

project = "virtex"
copyright = "2021, Karan Desai and Justin Johnson"
author = "Karan Desai"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. Each extension is listed exactly once.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.coverage",
    "sphinx.ext.doctest",
    "sphinx.ext.linkcode",
    "sphinx.ext.napoleon",
    "sphinx.ext.autosummary",
    "sphinx.ext.intersphinx",
    "sphinx.ext.mathjax",
    "sphinx_copybutton",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = ".rst"

# The master toctree document.
master_doc = "index"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# This version is used underneath the title on the index page.
version = "1.4"
# The full version, including alpha/beta/rc tags. The following is used if you
# need to also include a more detailed version. It is assigned once, here,
# next to ``version`` so the two cannot drift apart.
release = "1.4"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# These patterns also affect html_static_path and html_extra_path.
exclude_patterns = ["_build"]

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False

numpydoc_show_class_members = False


# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]


# -- Autodoc configuration ------------------------------------------------

autodoc_default_options = {
    "members": True,
    "member-order": "bysource",
    "private-members": True,
    "show-inheritance": True,
}


# -- Intersphinx configuration --------------------------------------------

intersphinx_mapping = {
    "torch": ("https://pytorch.org/docs/stable/", None),
    "albumentations": ("https://albumentations.readthedocs.io/en/latest/", None),
}


# -- Miscellaneous Extra Tweaks -------------------------------------------

# make github links resolve
def linkcode_resolve(domain, info):
    """
    Determine the URL corresponding to Python object

    This code is from
    https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L290
    and https://github.com/Lasagne/Lasagne/pull/262

    Args:
        domain: Domain name passed in by the caller; only ``"py"`` is handled.
        info: Mapping with keys ``"module"`` (dotted module name) and
            ``"fullname"`` (dotted attribute path inside that module).

    Returns:
        A GitHub ``blob/master`` URL built from the module name, with a
        ``#L<start>-L<end>`` fragment when the object's source lines can be
        determined; ``None`` when the domain is not ``"py"``, the module is
        not present in ``sys.modules``, an attribute in ``fullname`` cannot be
        resolved, or no source file can be found for the object.
    """
    if domain != "py":
        return None

    modname = info["module"]
    fullname = info["fullname"]

    # Only modules that are already imported can be inspected.
    submod = sys.modules.get(modname)
    if submod is None:
        return None

    # Walk the dotted path (e.g. "Class.method") down from the module.
    obj = submod
    for part in fullname.split("."):
        try:
            obj = getattr(obj, part)
        except Exception:
            # Best effort: attribute access may run arbitrary descriptor code.
            # ``Exception`` (not a bare except) lets KeyboardInterrupt and
            # SystemExit propagate.
            return None

    try:
        fn = inspect.getsourcefile(obj)
    except Exception:
        # e.g. TypeError for builtins and other objects without Python source.
        fn = None
    if not fn:
        return None

    try:
        source, lineno = inspect.getsourcelines(obj)
    except Exception:
        # Source file exists but lines are unavailable: link to the file only.
        lineno = None

    if lineno:
        linespec = "#L%d-L%d" % (lineno, lineno + len(source) - 1)
    else:
        linespec = ""

    filename = info["module"].replace(".", "/")
    return f"https://github.com/kdexd/virtex/blob/master/{filename}.py{linespec}"


# ================================================ FILE: docs/index.rst ================================================ .. raw:: html

VirTex: Learning Visual Representations from Textual Annotations

Karan Desai and Justin Johnson
University of Michigan


Abstract

The de-facto approach to many vision tasks is to start from pretrained visual representations, typically learned via supervised training on ImageNet. Recent methods have explored unsupervised pretraining to scale to vast quantities of unlabeled images. In contrast, we aim to learn high-quality visual representations from fewer images. To this end we revisit supervised pretraining, and seek data-efficient alternatives to classification-based pretraining. We propose VirTex -- a pretraining approach using semantically dense captions to learn visual representations. We train convolutional networks from scratch on COCO Captions, and transfer them to downstream recognition tasks including image classification, object detection, and instance segmentation. On all tasks, VirTex yields features that match or exceed those learned on ImageNet -- supervised or unsupervised -- despite using up to ten times fewer images.

**CVPR 2021. Paper available at:** `arxiv.org/abs/2006.06666 <https://arxiv.org/abs/2006.06666>`_. **Code available at:** `github.com/kdexd/virtex <https://github.com/kdexd/virtex>`_. .. image:: _static/system_figure.jpg Get the pretrained ResNet-50 visual backbone from our best performing VirTex model in one line *without any installation*! .. code-block:: python import torch # That's it, this one line only requires PyTorch. model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True) More details in :doc:`virtex/usage/model_zoo`. Next, dive deeper into our code with User Guide and API References! User Guide ---------- .. toctree:: :maxdepth: 2 virtex/usage/setup_dependencies virtex/usage/model_zoo virtex/usage/pretrain virtex/usage/downstream API Reference ------------- .. toctree:: :maxdepth: 2 virtex/config virtex/factories virtex/data virtex/models virtex/modules virtex/optim virtex/utils virtex/model_zoo Citation -------- If you find this code useful, please consider citing: .. code-block:: text @inproceedings{desai2021virtex, title={{VirTex: Learning Visual Representations from Textual Annotations}}, author={Karan Desai and Justin Johnson}, booktitle={CVPR}, year={2021} } Acknowledgments --------------- We thank Harsh Agrawal, Mohamed El Banani, Richard Higgins, Nilesh Kulkarni and Chris Rockwell for helpful discussions and feedback on the paper. We thank Ishan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for discussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and Yuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting the Instance Segmentation pretraining task ablation; and Stefan Lee for suggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs during project development; and UMich ARC-TS team for support with GPU cluster management. Finally, we thank all the Starbucks outlets in Ann Arbor for many hours of free WiFi. This work was partially supported by the Toyota Research Institute (TRI). 
However, note that this article solely reflects the opinions and conclusions of its authors and not TRI or any other Toyota entity. Indices and Tables ------------------ * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/virtex/config.rst ================================================ virtex.config ============= .. raw:: html
.. automodule:: virtex.config Config References ----------------- .. literalinclude:: ../../virtex/config.py :language: python :linenos: :lines: 42-210 :dedent: 8 ================================================ FILE: docs/virtex/data.datasets.rst ================================================ virtex.data.datasets ==================== .. raw:: html
Pretraining Datasets -------------------- .. automodule:: virtex.data.datasets.coco_captions .. automodule:: virtex.data.datasets.captioning .. automodule:: virtex.data.datasets.classification ------------------------------------------------------------------------------ Downstream Datasets ------------------- .. automodule:: virtex.data.datasets.downstream ================================================ FILE: docs/virtex/data.rst ================================================ virtex.data =========== .. raw:: html
.. toctree:: data.datasets data.tokenizers data.transforms ================================================ FILE: docs/virtex/data.tokenizers.rst ================================================ virtex.data.tokenizers ====================== .. raw:: html
.. automodule:: virtex.data.tokenizers ================================================ FILE: docs/virtex/data.transforms.rst ================================================ virtex.data.transforms ====================== .. raw:: html
.. automodule:: virtex.data.transforms ================================================ FILE: docs/virtex/factories.rst ================================================ virtex.factories ================ .. raw:: html
.. First only include the top-level module, and base class docstrings. .. automodule:: virtex.factories :no-members: .. autoclass:: virtex.factories.Factory ------------------------------------------------------------------------------ Dataloading-related Factories ----------------------------- .. autoclass:: virtex.factories.TokenizerFactory :members: from_config .. autoclass:: virtex.factories.ImageTransformsFactory :members: from_config .. autoclass:: virtex.factories.PretrainingDatasetFactory :members: from_config .. autoclass:: virtex.factories.DownstreamDatasetFactory :members: from_config ------------------------------------------------------------------------------ Modeling-related Factories -------------------------- .. autoclass:: virtex.factories.VisualBackboneFactory :members: from_config .. autoclass:: virtex.factories.TextualHeadFactory :members: from_config .. autoclass:: virtex.factories.PretrainingModelFactory :members: from_config ------------------------------------------------------------------------------ Optimization-related Factories ------------------------------ .. autoclass:: virtex.factories.OptimizerFactory :members: from_config .. autoclass:: virtex.factories.LRSchedulerFactory :members: from_config ================================================ FILE: docs/virtex/model_zoo.rst ================================================ virtex.model_zoo ================ .. raw:: html
.. automodule:: virtex.model_zoo.model_zoo ================================================ FILE: docs/virtex/models.rst ================================================ virtex.models ============= .. raw:: html
.. automodule:: virtex.models.classification ------------------------------------------------------------------------------- .. automodule:: virtex.models.captioning ------------------------------------------------------------------------------- .. automodule:: virtex.models.masked_lm ================================================ FILE: docs/virtex/modules.embedding.rst ================================================ virtex.modules.embedding ======================== .. raw:: html
.. automodule:: virtex.modules.embedding ================================================ FILE: docs/virtex/modules.rst ================================================ virtex.modules ============== .. raw:: html
.. toctree:: modules.embedding modules.visual_backbones modules.textual_heads ================================================ FILE: docs/virtex/modules.textual_heads.rst ================================================ virtex.modules.textual_heads ============================ .. raw:: html
.. automodule:: virtex.modules.textual_heads ================================================ FILE: docs/virtex/modules.visual_backbones.rst ================================================ virtex.modules.visual_backbones =============================== .. raw:: html
.. automodule:: virtex.modules.visual_backbones ================================================ FILE: docs/virtex/optim.lookahead.rst ================================================ virtex.optim.lookahead ====================== .. raw:: html
.. automodule:: virtex.optim.lookahead ================================================ FILE: docs/virtex/optim.lr_scheduler.rst ================================================ virtex.optim.lr_scheduler ========================= .. raw:: html
.. automodule:: virtex.optim.lr_scheduler ================================================ FILE: docs/virtex/optim.rst ================================================ virtex.optim ============ .. raw:: html
.. toctree:: optim.lookahead optim.lr_scheduler ================================================ FILE: docs/virtex/usage/downstream.rst ================================================ How to evaluate on downstream tasks? ==================================== In our paper, we evaluate our pretrained VirTex models on seven different downstream tasks. Our codebase supports all of these evaluations. Throughout this documentation, we consider a specific example of our VirTex pretrained model being evaluated for ensuring filepath uniformity in the following example command snippets. Paths can be trivially adjusted for any other VirTex model; evaluating the baselines (MoCo, ImageNet-supervised, Random Init) require additional changes in commands, explained in the last sub-section. As an example, consider a pretraining job for our best performing VirTex model (``width_ablations/bicaptioning_R_50_L1_H2048.yaml``). The serialization directory might look something like this: .. code-block:: text /tmp/bicaptioning_R_50_L1_H2048 pretrain_config.yaml log-rank0.txt # stdout/stderr per GPU process log-rank1.txt ... log-rank7.txt checkpoint_2000.pth checkpoint_4000.pth ... checkpoint_498000.pth checkpoint_500000.pth # serialized checkpoints train_captioning_forward/ events.out.* ... # tensorboard logs ... We evaluate all checkpoints on **PASCAL VOC 2007 Linear Classification**, and then evaluate the best checkpoint (here, it was iteration 500000) on all other downstream tasks. PASCAL VOC 2007 Linear Classification ------------------------------------- Evaluate a single VirTex pretrained checkpoint on VOC 2007 ``trainval`` split: .. 
code-block:: shell python scripts/clf_voc07.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --down-config configs/downstream/voc07_clf.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --weight-init virtex \ --num-gpus-per-machine 1 \ --cpu-workers 4 \ --serialization-dir /tmp/bicaptioning_R_50_L1_H2048 To evaluate the 100 most recent checkpoints in the sub-directory, this command can be looped over as follows: .. code-block:: shell for ((iter = 300000; iter <= 500000; iter+=2000)); do # add command with `checkpoint_$iter.pth` done This script writes metrics to tensorboard logs in the same pretraining directory, so all VOC07 mAP curves appear together with pretraining loss curves. ------------------------------------------------------------------------------- ImageNet Linear Classification ------------------------------ We train a linear classifier on 2048-dimensional global average pooled features extracted from a frozen visual backbone. Evaluate a checkpoint (for example, iteration 500000) on this task as: .. code-block:: shell python scripts/clf_linear.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --down-config configs/downstream/imagenet_clf.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --weight-init virtex \ --num-gpus-per-machine 8 \ --cpu-workers 4 \ --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/imagenet_500000 \ --checkpoint-every 5005 # 1 epoch of ImageNet ------------------------------------------------------------------------------- Instance Segmentation (and Object Detection) on COCO ---------------------------------------------------- Train a Mask R-CNN with FPN backbone for COCO Instance Segmentation (and Object Detection, because it also has a box head) by initializing the backbone from VirTex pretrained weights: .. 
code-block:: shell python scripts/eval_detectron2.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --d2-config configs/detectron2/coco_segm_default_init_2x.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --weight-init virtex \ --num-gpus-per-machine 8 \ --cpu-workers 2 \ --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/coco_segm_500000 \ --checkpoint-every 5000 .. note:: 1. This script periodically serializes checkpoints but skips validation step during training for saving time; to evaluate a serialized checkpoint and write results to tensorboard, provide it as ``--checkpoint-path`` and additional flags ``--resume --eval-only``. 2. Note that ``--d2-config`` here is in Detectron2 format, and not our package :class:`~virtex.config.Config`. These points are applicable for all tasks described below. ------------------------------------------------------------------------------- Instance Segmentation on LVIS ----------------------------- Train a Mask R-CNN with FPN backbone for LVIS Instance Segmentation by initializing the backbone from VirTex pretrained weights: .. code-block:: shell python scripts/eval_detectron2.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --d2-config configs/detectron2/lvis_segm_default_init_2x.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --weight-init virtex \ --num-gpus-per-machine 8 \ --cpu-workers 2 \ --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/lvis_segm_500000 \ --checkpoint-every 5000 ------------------------------------------------------------------------------- Object Detection on PASCAL VOC 2007+12 -------------------------------------- Train a Faster R-CNN with C4 backbone for PASCAL VOC 2007+12 Object Detection by initializing the backbone from VirTex pretrained weights: .. 
code-block:: shell python scripts/eval_detectron2.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --d2-config configs/detectron2/voc_det_default_init_24k.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --weight-init virtex \ --num-gpus-per-machine 8 \ --cpu-workers 2 \ --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/voc_det_500000 \ --checkpoint-every 2500 ------------------------------------------------------------------------------- iNaturalist 2018 Fine-Grained Classification -------------------------------------------- Fine-tune the VirTex pretrained visual backbone end-to-end on iNaturalist 2018 dataset: .. code-block:: shell python scripts/clf_linear.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --down-config configs/downstream/inaturalist_clf.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --weight-init virtex \ --num-gpus-per-machine 8 \ --cpu-workers 4 \ --serialization-dir /tmp/bicaptioning_R_50_L1_H2048/inaturalist_500000 \ --checkpoint-every 1710 # 1 epoch of iNaturalist ------------------------------------------------------------------------------- Image Captioning on COCO Captions val2017 ----------------------------------------- Evaluate a pretrained VirTex model on image captioning for COCO Captions val2017 split (reporting CIDEr and SPICE metrics): .. code-block:: shell python scripts/eval_captioning.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --calc-metrics \ --num-gpus-per-machine 1 \ --cpu-workers 4 ------------------------------------------------------------------------------- Running Image Captioning Inference on Arbitrary Images ------------------------------------------------------ The above script can be used for generating captions for any images in a directory. Replace certain commands as follows: .. 
code-block:: shell python scripts/eval_captioning.py \ --config /tmp/bicaptioning_R_50_L1_H2048/pretrain_config.yaml \ --checkpoint-path /tmp/bicaptioning_R_50_L1_H2048/checkpoint_500000.pth \ --data-root /path/to/images_dir \ --output /path/to/save/predictions.json \ --num-gpus-per-machine 1 \ --cpu-workers 4 This script will save predictions in JSON format. Since our goal is to not improve image captioning, these models may not generate the best captions. ================================================ FILE: docs/virtex/usage/model_zoo.rst ================================================ VirTex Model Zoo ================ We provide a collection of pretrained model weights and corresponding config names in this model zoo. Tables contain partial paths to config files for each model, download link for pretrained weights and for reference -- VOC07 mAP and ImageNet top-1 accuracy. The simplest way to download and use a *full* pretrained model (including both, the visual backbone and the textual head) is through :doc:`../model_zoo` API as follows. This code snippet works from anywhere, and does not require to be executed from project root. .. code-block:: python # Get our full best performing VirTex model: import virtex.model_zoo as mz model = mz.get("width_ablations/bicaptioning_R_50_L1_H2048.yaml", pretrained=True) # Optionally extract the torchvision-like visual backbone (with ``avgpool`` # and ``fc`` layers replaced with ``nn.Identity`` module). cnn = model.visual.cnn Alternatively, weights can be manually downloaded from links below, and this can be executed from the project root: .. 
code-block:: python from virtex.config import Config from virtex.factories import PretrainingModelFactory from virtex.utils.checkpointing import CheckpointManager # Get the best performing VirTex model: _C = Config("configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml") model = PretrainingModelFactory.from_config(_C) CheckpointManager(model=model).load("/path/to/downloaded/weights.pth") # Optionally extract the torchvision-like visual backbone (with ``avgpool`` # and ``fc`` layers replaced with ``nn.Identity`` module). cnn = model.visual.cnn The pretrained ResNet-50 visual backbone of our best performing model (``width_ablations/bicaptioning_R_50_L1_H2048.yaml``) can be loaded in a single line, *without following any installation steps* (only requires PyTorch v1.5): .. code-block:: python import torch model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True) # This is a torchvision-like resnet50 model, with ``avgpool`` and ``fc`` # layers replaced with ``nn.Identity`` module. image_batch = torch.randn(1, 3, 224, 224) # batch tensor of one image. features_batch = model(image_batch) # shape: (1, 2048, 7, 7) ------------------------------------------------------------------------------- Pretraining Task Ablations ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. raw:: html
Model Config Name VOC07
mAP
ImageNet
Top-1 Acc.
Model URL
task_ablations/bicaptioning_R_50_L1_H2048.yaml 88.7 53.8 model
task_ablations/captioning_R_50_L1_H2048.yaml 88.6 50.8 model
task_ablations/token_classification_R_50.yaml 88.8 48.6 model
task_ablations/multilabel_classification_R_50.yaml 86.2 46.2 model
task_ablations/masked_lm_R_50_L1_H2048.yaml 86.4 46.7 model
Width Ablations ^^^^^^^^^^^^^^^ .. raw:: html
Model Config Name VOC07
mAP
ImageNet
Top-1 Acc.
Model URL
width_ablations/bicaptioning_R_50_L1_H512.yaml 88.4 51.8 model
width_ablations/bicaptioning_R_50_L1_H768.yaml 88.3 52.3 model
width_ablations/bicaptioning_R_50_L1_H1024.yaml 88.3 53.2 model
width_ablations/bicaptioning_R_50_L1_H2048.yaml 88.7 53.8 model
Depth Ablations ^^^^^^^^^^^^^^^ .. raw:: html
Model Config Name VOC07
mAP
ImageNet
Top-1 Acc.
Model URL
depth_ablations/bicaptioning_R_50_L1_H1024.yaml 88.3 53.2 model
depth_ablations/bicaptioning_R_50_L2_H1024.yaml 88.8 53.8 model
depth_ablations/bicaptioning_R_50_L3_H1024.yaml 88.7 53.9 model
depth_ablations/bicaptioning_R_50_L4_H1024.yaml 88.7 53.9 model
Backbone Ablations ^^^^^^^^^^^^^^^^^^ .. raw:: html
Model Config Name VOC07
mAP
ImageNet
Top-1 Acc.
Model URL
backbone_ablations/bicaptioning_R_50_L1_H1024.yaml 88.3 53.2 model
backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml 88.5 52.9 model
backbone_ablations/bicaptioning_R_101_L1_H1024.yaml 88.7 52.1 model
================================================ FILE: docs/virtex/usage/pretrain.rst ================================================ How to train your VirTex model? =============================== We provide training scripts for all types of VirTex models from the paper; including our best-performing model and other ablations. Our training jobs are specified by config files (YAML). Execute all commands from project root to use the provided config files. Training the base VirTex model ------------------------------ Train the base VirTex model with ResNet-50 visual backbone; and a textual head with ``L = 1, H = 1024`` using all default optimization hyperparameters. .. code-block:: python scripts/pretrain_virtex.py \ --config configs/_base_bicaptioning_R_50_L1_H1024.yaml \ --num-gpus-per-machine 8 \ --cpu-workers 4 \ --serialization-dir /tmp/VIRTEX_R_50_L1_H1024 # Default: --checkpoint-every 2000 --log-every 20 Training job will save checkpoints, tensorboard logs (loss curves and metrics), and back up the config in ``--serialization-dir``. Use ``tensorboard --logdir <serialization_dir>`` to view training curves, validation metrics etc. directly on tensorboard. We recommend training with 8 GPUs on the same machine, although training with multiple GPUs across machines (see: ``--num-machines`` and ``--machine-rank``), single GPU (``--num-gpus-per-machine 1``) as well as CPU (``--num-gpus-per-machine 0``) is also supported. Using multiple GPUs for interactive debugging with PDB is not supported, as PDB and ``multiprocessing`` module do not play nice. ------------------------------------------------------------------------------- Reproducing all VirTex ablations -------------------------------- To reproduce all ablations from the `paper <https://arxiv.org/abs/2006.06666>`_, replace the ``--config`` argument in above command with the following (all assumed to be relative to project root): Pretraining Task Ablations ^^^^^^^^^^^^^^^^^^^^^^^^^^ 1. **Bicaptioning:** configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml 2. 
**Forward Captioning:** configs/task_ablations/captioning_R_50_L1_H2048.yaml 3. **Token Classification:** configs/task_ablations/token_classification_R_50.yaml 4. **Multilabel Classification:** configs/task_ablations/multilabel_classification_R_50.yaml 5. **Masked Language Modeling:** configs/task_ablations/masked_lm_R_50_L1_H2048.yaml Transformer Size Ablations ^^^^^^^^^^^^^^^^^^^^^^^^^^ 1. **Width (H = 512):** configs/width_ablations/bicaptioning_R_50_L1_H512.yaml 2. **Width (H = 768):** configs/width_ablations/bicaptioning_R_50_L1_H768.yaml 3. **Width (H = 1024):** configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml 4. **Width (H = 2048):** configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml 5. **Depth (L = 1):** configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml 6. **Depth (L = 2):** configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml 7. **Depth (L = 3):** configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml 8. **Depth (L = 4):** configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml Backbone Ablations ^^^^^^^^^^^^^^^^^^ 1. **ResNet-50:** configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml 2. **ResNet-50 w2x:** configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml 3. **ResNet-101:** configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml .. note:: **Pretraining Task Ablations** (1), **Transformer Size Ablations** (3 and 5) and **Backbone Ablations** (1) are all the same exact model. ================================================ FILE: docs/virtex/usage/setup_dependencies.rst ================================================ How to setup this codebase? =========================== .. raw:: html
This codebase requires Python 3.6 or higher. We recommend using Anaconda or Miniconda. We walk through installation and data preprocessing here. Install Dependencies -------------------- Follow these steps to install through Anaconda (or Miniconda). 1. Install Anaconda or Miniconda distribution based on Python 3+ from their `downloads site `_. 2. Clone the repository first. .. code-block:: shell git clone https://www.github.com/kdexd/virtex 3. Create a conda environment and install all the dependencies. .. code-block:: shell cd virtex conda create -n virtex python=3.8 conda activate virtex pip install -r requirements.txt 4. Install additional packages from Github. .. code-block:: shell pip install git+git://github.com/facebookresearch/fvcore.git#egg=fvcore pip install git+git://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI 5. Install this codebase as a package in development version. .. code-block:: shell python setup.py develop Now you can ``import virtex`` from anywhere as long as you have this conda environment activated. ------------------------------------------------------------------------------- Setup Datasets -------------- Datasets are assumed to exist in ``./datasets`` directory (relative to the project root) following the structure specified below. COCO is used for pretraining, and rest of the datasets (including COCO) are used for downstream tasks. This structure is compatible when using `Detectron2 `_ for downstream tasks. COCO ^^^^ .. code-block:: datasets/coco/ annotations/ captions_{train,val}2017.json instances_{train,val}2017.json train2017/ # images in train2017 split val2017/ # images in val2017 split LVIS ^^^^ .. code-block:: datasets/coco/ train2017/ val2017/ datasets/lvis/ lvis_v1.0_{train,val}.json PASCAL VOC ^^^^^^^^^^ .. code-block:: datasets/VOC2007/ Annotations/ ImageSets/ Main/ trainval.txt test.txt JPEGImages/ datasets/VOC2012/ # Same as VOC2007 above ImageNet ^^^^^^^^ .. 
code-block:: datasets/imagenet/ train/ # One directory per category with images in it val/ # One directory per category with images in it ILSVRC2012_devkit_t12.tar.gz iNaturalist 2018 ^^^^^^^^^^^^^^^^ .. code-block:: datasets/inaturalist/ train_val2018/ annotations/ train2018.json val2018.json ------------------------------------------------------------------------------- Build vocabulary ---------------- Build a vocabulary out of COCO Captions ``train2017`` split. .. code-block:: shell python scripts/build_vocabulary.py \ --captions datasets/coco/annotations/captions_train2017.json \ --vocab-size 10000 \ --output-prefix datasets/vocab/coco_10k \ --do-lower-case That's it! You are all set to use this codebase. ================================================ FILE: docs/virtex/utils.beam_search.rst ================================================ virtex.utils.beam_search ======================== .. raw:: html
.. automodule:: virtex.utils.beam_search ================================================ FILE: docs/virtex/utils.checkpointing.rst ================================================ virtex.utils.checkpointing ========================== .. raw:: html
.. automodule:: virtex.utils.checkpointing ================================================ FILE: docs/virtex/utils.common.rst ================================================ virtex.utils.common =================== .. raw:: html
.. automodule:: virtex.utils.common ================================================ FILE: docs/virtex/utils.distributed.rst ================================================ virtex.utils.distributed ======================== .. raw:: html
.. automodule:: virtex.utils.distributed ================================================ FILE: docs/virtex/utils.metrics.rst ================================================ virtex.utils.metrics ==================== .. raw:: html
.. automodule:: virtex.utils.metrics ================================================ FILE: docs/virtex/utils.rst ================================================ virtex.utils ============ .. raw:: html
.. toctree:: utils.common utils.distributed utils.timer utils.checkpointing utils.beam_search utils.metrics ================================================ FILE: docs/virtex/utils.timer.rst ================================================ virtex.utils.timer ================== .. raw:: html
def resnet50(pretrained: bool = False, **kwargs):
    r"""
    ResNet-50 visual backbone from the best performing VirTex model: pretrained
    for bicaptioning on COCO Captions, with textual head ``L = 1, H = 2048``.

    This is a torchvision-like model, with the last ``avgpool`` and ``fc``
    modules replaced with ``nn.Identity()`` modules. Given a batch of image
    tensors with size ``(B, 3, 224, 224)``, this model computes image features
    with 2048 channels over a ``7 x 7`` spatial grid, where B = batch size.

    NOTE(review): the exact output layout should be confirmed. The usage docs
    show ``(B, 2048, 7, 7)``, while torchvision's ``ResNet.forward`` appears to
    flatten features after ``avgpool``, which would give ``(B, 2048 * 7 * 7)``.

    Args:
        pretrained (bool): Whether to load model with pretrained weights.
        **kwargs: Extra keyword arguments forwarded to
            ``torchvision.models.resnet50``.

    Returns:
        The ResNet-50 model, with VirTex weights loaded if ``pretrained``.
    """

    # Create a torchvision resnet50 with randomly initialized weights.
    model = torchvision.models.resnet50(pretrained=False, **kwargs)

    # Replace global average pooling and fully connected layers with identity
    # modules.
    model.avgpool = torch.nn.Identity()
    model.fc = torch.nn.Identity()

    if pretrained:
        # Map tensors to CPU while deserializing, so that loading also works on
        # machines without a GPU; ``load_state_dict`` then copies the values
        # into the (CPU) parameters of the freshly created model.
        state_dict = torch.hub.load_state_dict_from_url(
            R50_URL, map_location="cpu", progress=False
        )
        model.load_state_dict(state_dict)
    return model
def _read_captions(annotations_path: str) -> List[str]:
    r"""
    Given a path to annotation file, read it and return a list of captions.
    These are not processed by any means, returned from the file as-is.

    Args:
        annotations_path: Path to an annotations file containing captions.

    Returns:
        List of captions from this annotation file.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(annotations_path, encoding="utf-8") as annotations_file:
        _annotations = json.load(annotations_file)

    return [ann["caption"] for ann in _annotations["annotations"]]


def _normalize_caption(caption: str, do_lower_case: bool, keep_accents: bool) -> str:
    r"""
    Optionally lower case a caption and strip accents from it.

    Args:
        caption: Caption to process.
        do_lower_case: Whether to lower case the caption.
        keep_accents: Whether to keep accents. If ``False``, the caption is
            NFKD-normalized and all combining characters are dropped.

    Returns:
        The processed caption.
    """
    if do_lower_case:
        caption = caption.lower()

    if not keep_accents:
        # NFKD splits accented characters into base character + combining
        # mark(s); dropping the combining marks removes the accents.
        caption = unicodedata.normalize("NFKD", caption)
        caption = "".join(
            char for char in caption if not unicodedata.combining(char)
        )
    return caption


if __name__ == "__main__":

    _A = parser.parse_args()

    # Lower case the captions and remove accents according to arguments.
    captions: List[str] = [
        _normalize_caption(caption, _A.do_lower_case, _A.keep_accents)
        for caption in _read_captions(_A.captions)
    ]

    # Make sure the directory for output files exists before training.
    output_dir = os.path.dirname(_A.output_prefix)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Create a temporary directory and dump the captions corpus as a text file
    # with one caption per line. That's how sentencepiece wants its input.
    # The directory (and the corpus file) is removed once training finishes.
    with tempfile.TemporaryDirectory() as tmpdir_path:
        captions_path = os.path.join(tmpdir_path, "captions.txt")

        with open(captions_path, "w", encoding="utf-8") as captions_file:
            captions_file.writelines(caption + "\n" for caption in captions)

        # Padding/out-of-vocab token will be "<unk>" and ID 0 by default.
        # Add [SOS],[EOS] and [MASK] tokens. [MASK] will not be used during
        # captioning, but good to have to reuse vocabulary across pretext tasks.
        sp.SentencePieceTrainer.train(
            f" --input={captions_path}"
            f" --vocab_size={_A.vocab_size}"
            f" --model_prefix={_A.output_prefix}"
            " --model_type=bpe --character_coverage=1.0"
            " --bos_id=-1 --eos_id=-1"
            " --control_symbols=[SOS],[EOS],[MASK]"
        )
def main(_A: argparse.Namespace):
    r"""
    Train a classifier on top of a visual backbone and periodically evaluate
    top-1 accuracy on the ``val`` split. Depending on
    ``MODEL.VISUAL.FROZEN`` in the downstream config, either only the final
    ``fc`` layer is trained or the whole backbone is fine-tuned.

    Args:
        _A: Parsed command line arguments of this script.

    Raises:
        KeyError: If the dataset name derived from ``DATA.ROOT`` of the
            downstream config is not one of the supported datasets.
    """

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device as set for current distributed process.
        # Check `launch` function in `virtex.utils.distributed` module.
        device = torch.cuda.current_device()

    # Create a downstream config object (this will be immutable) and perform
    # common setup such as logging and setting up serialization directory.
    _DOWNC = Config(_A.down_config, _A.down_config_override)
    common_setup(_DOWNC, _A, job_type="downstream")

    # Create a (pretraining) config object and backup in serialization directory.
    _C = Config(_A.config, _A.config_override)
    _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml"))

    # Get dataset name for tensorboard logging. Normalize the path first so
    # that a trailing slash in ``DATA.ROOT`` does not yield an empty name.
    DATASET = os.path.basename(os.path.normpath(_DOWNC.DATA.ROOT))

    # Set number of output classes according to dataset:
    NUM_CLASSES_MAPPING = {"imagenet": 1000, "inaturalist": 8142}
    if DATASET not in NUM_CLASSES_MAPPING:
        raise KeyError(
            f"Unsupported dataset '{DATASET}' (derived from DATA.ROOT), "
            f"expected one of {sorted(NUM_CLASSES_MAPPING)}."
        )
    NUM_CLASSES = NUM_CLASSES_MAPPING[DATASET]

    # -------------------------------------------------------------------------
    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
    # -------------------------------------------------------------------------
    train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="train")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            train_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=True,
        ),
        drop_last=False,
        pin_memory=True,
        collate_fn=train_dataset.collate_fn,
    )
    val_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="val")
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=_DOWNC.OPTIM.BATCH_SIZE // dist.get_world_size(),
        num_workers=_A.cpu_workers,
        sampler=DistributedSampler(
            val_dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=False,
        ),
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )
    # Initialize model using pretraining config.
    pretrained_model = PretrainingModelFactory.from_config(_C)

    # Load weights according to the init method, do nothing for `random`, and
    # `imagenet` is already taken care of.
    if _A.weight_init == "virtex":
        CheckpointManager(model=pretrained_model).load(_A.checkpoint_path)
    elif _A.weight_init == "torchvision":
        # Keep strict=False because this state dict may have weights for
        # last fc layer.
        pretrained_model.visual.cnn.load_state_dict(
            torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"],
            strict=False,
        )

    # Pull out the CNN (torchvision-like) from our pretrained model and add
    # back the FC layer - this exists in torchvision models, and is set to
    # `nn.Identity()` during pretraining.
    model = pretrained_model.visual.cnn  # type: ignore
    model.fc = nn.Linear(_DOWNC.MODEL.VISUAL.FEATURE_SIZE, NUM_CLASSES).to(device)
    model = model.to(device)

    # Re-initialize the FC layer.
    torch.nn.init.normal_(model.fc.weight.data, mean=0.0, std=0.01)
    torch.nn.init.constant_(model.fc.bias.data, 0.0)

    # Freeze all layers except FC as per config param.
    if _DOWNC.MODEL.VISUAL.FROZEN:
        # Set model to eval mode to prevent BatchNorm from updating running
        # mean and std. With only a linear layer, being in eval mode when
        # training will not matter anyway.
        model.eval()

        for name, param in model.named_parameters():
            if "fc" not in name:
                param.requires_grad = False

    # Cross entropy loss and accuracy meter.
    criterion = nn.CrossEntropyLoss()
    top1 = TopkAccuracy(k=1)

    optimizer = OptimizerFactory.from_config(_DOWNC, model.named_parameters())
    scheduler = LRSchedulerFactory.from_config(_DOWNC, optimizer)
    del pretrained_model

    # -------------------------------------------------------------------------
    #  BEFORE TRAINING STARTS
    # -------------------------------------------------------------------------

    # Create a gradient scaler for automatic mixed precision.
    scaler = amp.GradScaler(enabled=_DOWNC.AMP)

    # Create an iterator from dataloader to sample batches perpetually.
    train_dataloader_iter = cycle(train_dataloader, device)

    if dist.get_world_size() > 1:
        dist.synchronize()
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )

    if dist.is_master_process():
        checkpoint_manager = CheckpointManager(
            _A.serialization_dir,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)

    # Keep track of time per iteration and ETA.
    timer = Timer(start_from=1, total_iterations=_DOWNC.OPTIM.NUM_ITERATIONS)

    # -------------------------------------------------------------------------
    #   TRAINING LOOP
    # -------------------------------------------------------------------------
    for iteration in range(1, _DOWNC.OPTIM.NUM_ITERATIONS + 1):
        timer.tic()
        optimizer.zero_grad()
        batch = next(train_dataloader_iter)

        with amp.autocast(enabled=_DOWNC.AMP):
            logits = model(batch["image"])
            loss = criterion(logits, batch["label"])

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()
        timer.toc()

        if iteration % _A.log_every == 0 and dist.is_master_process():
            logger.info(
                f"{timer.stats} | Loss: {loss:.3f} | GPU: {dist.gpu_mem_usage()} MB"
            )
            tensorboard_writer.add_scalar(f"{DATASET}/train_loss", loss, iteration)
            tensorboard_writer.add_scalar(
                f"{DATASET}/learning_rate",
                optimizer.param_groups[0]["lr"],
                iteration,
            )

        # ---------------------------------------------------------------------
        #   VALIDATION
        # ---------------------------------------------------------------------
        if iteration % _A.checkpoint_every == 0:
            model.eval()

            total_val_loss = torch.tensor(0.0).to(device)
            # Stays zero if the val dataloader yields nothing on this process.
            num_val_batches = 0

            # A context manager (rather than toggling the global grad mode)
            # guarantees gradients are re-enabled even if validation raises.
            with torch.no_grad():
                for num_val_batches, val_batch in enumerate(val_dataloader, start=1):
                    for key in val_batch:
                        val_batch[key] = val_batch[key].to(device)

                    val_logits = model(val_batch["image"])
                    val_loss = criterion(val_logits, val_batch["label"])
                    _ = top1(val_logits, val_batch["label"])
                    total_val_loss += val_loss

            # Divide each loss component by number of val batches per GPU.
            total_val_loss = total_val_loss / max(num_val_batches, 1)
            dist.average_across_processes(total_val_loss)

            # Get accumulated Top-1 accuracy for logging across GPUs.
            acc = top1.get_result()
            top1.reset()
            dist.average_across_processes(acc)

            # Set model back to train mode only when fine-tuning end-to-end.
            if not _DOWNC.MODEL.VISUAL.FROZEN:
                model.train()

            # Save recent checkpoint and best checkpoint based on accuracy.
            if dist.is_master_process():
                checkpoint_manager.step(iteration)
                logger.info(f"Iter: {iteration} | Top-1 accuracy: {acc}")
                tensorboard_writer.add_scalar(
                    f"{DATASET}/val_loss", total_val_loss, iteration
                )
                # This name scoping will result in Tensorboard displaying all
                # metrics (VOC07, caption, etc.) together.
                tensorboard_writer.add_scalars(
                    f"metrics/{DATASET}", {"top1": acc}, iteration
                )

        # All processes will wait till master process is done logging.
        dist.synchronize()
dist.launch( main, num_machines=_A.num_machines, num_gpus_per_machine=_A.num_gpus_per_machine, machine_rank=_A.machine_rank, dist_url=_A.dist_url, args=(_A,), ) ================================================ FILE: scripts/clf_voc07.py ================================================ import argparse import multiprocessing as mp import os from typing import Any, List import numpy as np import torch from loguru import logger from sklearn.svm import LinearSVC from sklearn.metrics import average_precision_score from sklearn.model_selection import cross_val_score from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from virtex.config import Config from virtex.factories import PretrainingModelFactory, DownstreamDatasetFactory from virtex.utils.checkpointing import CheckpointManager from virtex.utils.common import common_parser, common_setup parser = common_parser( description="Train SVMs for VOC2007 classification on a pretrained model." ) group = parser.add_argument_group("Downstream config arguments.") group.add_argument( "--down-config", metavar="FILE", help="Path to a downstream config file." ) group.add_argument( "--down-config-override", nargs="*", default=[], help="A list of key-value pairs to modify downstream config params.", ) # fmt: off parser.add_argument_group("Checkpointing") parser.add_argument( "--weight-init", choices=["random", "imagenet", "torchvision", "virtex"], default="virtex", help="""How to initialize weights: 1. 'random' initializes all weights randomly 2. 'imagenet' initializes backbone weights from torchvision model zoo 3. {'torchvision', 'virtex'} load state dict from --checkpoint-path - with 'torchvision', state dict would be from PyTorch's training script. - with 'virtex' it should be for our full pretrained model.""" ) parser.add_argument( "--checkpoint-path", help="Path to load checkpoint and run downstream task evaluation." 
def train_test_single_svm(args):
    r"""
    Train a one-vs-all linear SVM for a single class and evaluate it.

    The SVM cost is selected by 3-fold cross-validation (average precision)
    on the training features; one SVM with the selected cost is then fitted
    on all training features and scored on the test features.

    Args:
        args: A 5-tuple ``(feats_train, tgts_train, feats_test, tgts_test,
            cls_name)``, packed into one argument so this function can be
            used with ``multiprocessing.Pool.map``. Targets are per-example
            labels for this class: 1 = present, 0 = not present; test
            examples labelled -1 are excluded from evaluation.

    Returns:
        Average precision of the selected SVM on the evaluated test examples.
    """

    feats_train, tgts_train, feats_test, tgts_test, cls_name = args
    SVM_COSTS = [0.01, 0.1, 1.0, 10.0]

    cls_labels = np.copy(tgts_train)
    # Meaning of labels in VOC/COCO original loaded target files:
    # label 0 = not present, set it to -1 as svm train target
    # label 1 = present. Make the svm train target labels as -1, 1.
    cls_labels[np.where(cls_labels == 0)] = -1

    def _build_svm(cost: float) -> LinearSVC:
        # All candidate SVMs share every hyperparameter except the cost.
        return LinearSVC(
            C=cost,
            class_weight={1: 2, -1: 1},
            penalty="l2",
            loss="squared_hinge",
            max_iter=2000,
        )

    # See which cost maximizes the AP for this class. Start from -inf and a
    # valid default cost so that a classifier is always selected, even when
    # every cross-validated AP is 0.0 or NaN.
    best_crossval_ap: float = float("-inf")
    best_cost: float = SVM_COSTS[0]

    for cost in SVM_COSTS:
        ap_scores = cross_val_score(
            _build_svm(cost),
            feats_train,
            cls_labels,
            cv=3,
            scoring="average_precision",
        )
        mean_ap = ap_scores.mean()

        # Keep track of best cost for this class.
        if mean_ap > best_crossval_ap:
            best_crossval_ap = mean_ap
            best_cost = cost

    # Fit only the selected SVM on the full training set, instead of fitting
    # every candidate and discarding all but one.
    best_crossval_clf = _build_svm(best_cost)
    best_crossval_clf.fit(feats_train, cls_labels)

    logger.info(f"Best SVM {cls_name}: cost {best_cost}, mAP {best_crossval_ap * 100}")

    # -------------------------------------------------------------------------
    #   TEST THE TRAINED SVM (PER CLASS)
    # -------------------------------------------------------------------------
    predictions = best_crossval_clf.decision_function(feats_test)

    # Skip test examples whose label is -1 for this class.
    evaluate_data_inds = tgts_test != -1
    eval_preds = predictions[evaluate_data_inds]

    eval_cls_labels = np.copy(tgts_test)[evaluate_data_inds]
    eval_cls_labels[np.where(eval_cls_labels == 0)] = -1

    # Binarize class labels to make AP targets.
    targets = eval_cls_labels > 0
    return average_precision_score(targets, eval_preds)
device = torch.cuda.current_device() # Create a downstream config object (this will be immutable) and perform # common setup such as logging and setting up serialization directory. _DOWNC = Config(_A.down_config, _A.down_config_override) common_setup(_DOWNC, _A, job_type="downstream") # Create a (pretraining) config object and backup in serialization directory. _C = Config(_A.config, _A.config_override) _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) # ------------------------------------------------------------------------- # INSTANTIATE DATALOADER, MODEL, AND FEATURE EXTRACTOR # ------------------------------------------------------------------------- train_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="trainval") train_dataloader = DataLoader( train_dataset, batch_size=_DOWNC.OPTIM.BATCH_SIZE, num_workers=_A.cpu_workers, pin_memory=True, ) test_dataset = DownstreamDatasetFactory.from_config(_DOWNC, split="test") test_dataloader = DataLoader( test_dataset, batch_size=_DOWNC.OPTIM.BATCH_SIZE, num_workers=_A.cpu_workers, pin_memory=True, ) NUM_CLASSES = len(train_dataset.class_names) # Initialize from a checkpoint, but only keep the visual module. model = PretrainingModelFactory.from_config(_C) # Load weights according to the init method, do nothing for `random`, and # `imagenet` is already taken care of. if _A.weight_init == "virtex": ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path) elif _A.weight_init == "torchvision": # Keep strict=False because this state dict may have weights for # last fc layer. model.visual.cnn.load_state_dict( torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], strict=False, ) # Set ``ITERATION`` to a dummy value. ITERATION = 0 # Transfer model to GPU and set to eval mode. This is a torchvision model # and it returns features as ``(batch_size, 2048, 7, 7)``. 
model = model.visual.cnn.to(device).eval() # ------------------------------------------------------------------------- # EXTRACT FEATURES FOR TRAINING SVMs # ------------------------------------------------------------------------- features_train: List[torch.Tensor] = [] targets_train: List[torch.Tensor] = [] features_test: List[torch.Tensor] = [] targets_test: List[torch.Tensor] = [] # VOC07 is small, extract all features and keep them in memory. with torch.no_grad(): for batch in tqdm(train_dataloader, desc="Extracting train features:"): features = model(batch["image"].to(device)) # Global average pool features. Assume the tensor is in NCHW format. if len(features.size()) > 2: # shape: (batch_size, visual_feature_size) features = features.mean(dim=(2, 3)) # L2-normalize the global average pooled features. features = F.normalize(features, dim=-1) features_train.append(features.cpu()) targets_train.append(batch["label"]) # Similarly extract test features. for batch in tqdm(test_dataloader, desc="Extracting test features:"): features = model(batch["image"].to(device)) if len(features.size()) > 2: features = features.mean(dim=(2, 3)) features = F.normalize(features, dim=-1) features_test.append(features.cpu()) targets_test.append(batch["label"]) # Convert batches of features/targets to one large numpy array features_train = torch.cat(features_train, dim=0).numpy() targets_train = torch.cat(targets_train, dim=0).numpy().astype(np.int32) features_test = torch.cat(features_test, dim=0).numpy() targets_test = torch.cat(targets_test, dim=0).numpy().astype(np.int32) # ------------------------------------------------------------------------- # TRAIN AND TEST SVMs WITH EXTRACTED FEATURES # ------------------------------------------------------------------------- input_args: List[Any] = [] # Iterate over all VOC07 classes and train one-vs-all linear SVMs. 
for cls_idx in range(NUM_CLASSES): # fmt: off input_args.append(( features_train, targets_train[:, cls_idx], features_test, targets_test[:, cls_idx], train_dataset.class_names[cls_idx], )) # fmt: on pool = mp.Pool(processes=_A.cpu_workers) pool_output = pool.map(train_test_single_svm, input_args) # ------------------------------------------------------------------------- # TENSORBOARD LOGGING (RELEVANT MAINLY FOR weight_init=checkpoint) # ------------------------------------------------------------------------- # Tensorboard writer for logging mAP scores. This is useful especially # when weight_init=checkpoint (which maybe be coming from a training job). tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir) # Test set mAP for each class, for features from every layer. test_map = torch.tensor(pool_output).mean() logger.info(f"Iteration: {ITERATION}, mAP: {test_map * 100}") tensorboard_writer.add_scalars( "metrics/voc07_clf", {f"voc07_mAP": test_map * 100}, ITERATION ) if __name__ == "__main__": _A = parser.parse_args() if _A.num_gpus_per_machine > 1: raise ValueError("Using multiple GPUs is not supported for this script.") # Add an arg in config override if `--weight-init` is imagenet. if _A.weight_init == "imagenet": _A.config_override.extend(["MODEL.VISUAL.PRETRAINED", True]) # No distributed training here, just a single process. 
def main(_A: argparse.Namespace):
    r"""
    Run captioning inference on a directory of images with a pretrained
    checkpoint, optionally save predictions as JSON, and optionally compute
    captioning metrics against COCO Captions val2017 ground truth.

    Args:
        _A: Parsed command line arguments of this script. ``_A.data_root`` is
            set by this function to the image directory that was used.
    """

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device = torch.device("cpu")
    else:
        # Get the current device (this will be zero here by default).
        device = torch.cuda.current_device()

    _C = Config(_A.config, _A.config_override)

    tokenizer = TokenizerFactory.from_config(_C)

    # The ``--images/--data-root`` option is stored by argparse under the name
    # of its first long flag (``images``), so ``data_root`` may not exist on
    # the namespace. Accept either attribute, then fall back to the default.
    data_root = getattr(_A, "data_root", None)
    if data_root is None:
        data_root = getattr(_A, "images", None)
    if data_root is None:
        data_root = os.path.join(_C.DATA.ROOT, "val2017")
    _A.data_root = data_root

    val_dataloader = DataLoader(
        ImageDirectoryDataset(_A.data_root),
        batch_size=_C.OPTIM.BATCH_SIZE,
        num_workers=_A.cpu_workers,
        pin_memory=True,
    )
    # Initialize model from a checkpoint.
    model = PretrainingModelFactory.from_config(_C).to(device)
    ITERATION = CheckpointManager(model=model).load(_A.checkpoint_path)
    model.eval()

    # Make a list of predictions to evaluate.
    predictions: List[Dict[str, Any]] = []

    for val_batch in val_dataloader:
        val_batch["image"] = val_batch["image"].to(device)
        with torch.no_grad():
            output_dict = model(val_batch)

        # Make a dictionary of predictions in COCO format.
        for image_id, caption in zip(
            val_batch["image_id"], output_dict["predictions"]
        ):
            predictions.append(
                {
                    # Convert image id to int if possible (mainly for COCO eval).
                    "image_id": int(image_id) if image_id.isdigit() else image_id,
                    "caption": tokenizer.decode(caption.tolist()),
                }
            )

    logger.info("Displaying first 25 caption predictions:")
    for pred in predictions[:25]:
        logger.info(f"{pred['image_id']} :: {pred['caption']}")

    # Save predictions as a JSON file if specified.
    if _A.output is not None:
        # ``dirname`` is empty for a bare filename; ``os.makedirs("")`` would
        # raise, so only create the directory when there is one.
        output_dir = os.path.dirname(_A.output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with open(_A.output, "w") as output_file:
            json.dump(predictions, output_file)
        logger.info(f"Saved predictions to {_A.output}")

    # Calculate CIDEr and SPICE metrics using ground truth COCO Captions. This
    # should be skipped when running inference on arbitrary images.
    if _A.calc_metrics:
        # Assume ground truth (COCO val2017 annotations) exist.
        gt = os.path.join(_C.DATA.ROOT, "annotations", "captions_val2017.json")

        metrics = CocoCaptionsEvaluator(gt).evaluate(predictions)
        logger.info(f"Iter: {ITERATION} | Metrics: {metrics}")
@ROI_HEADS_REGISTRY.register()
class Res5ROIHeadsExtraNorm(Res5ROIHeads):
    r"""
    ``Res5ROIHeads`` variant whose ``res5`` block is followed by one extra
    normalization layer (a BN layer in the configs this is used with). Meant
    for Faster R-CNN C4/DC5 backbones on VOC detection.
    """

    def _build_res5_block(self, cfg):
        # Build the regular ``res5`` block first, then append a norm layer of
        # the type named by ``cfg.MODEL.RESNETS.NORM`` over its output channels.
        res5_block, num_channels = super()._build_res5_block(cfg)
        res5_block.add_module(
            "norm", d2.layers.get_norm(cfg.MODEL.RESNETS.NORM, num_channels)
        )
        return res5_block, num_channels
class DownstreamTrainer(DefaultTrainer):
    r"""
    Extension of detectron2's ``DefaultTrainer``: custom evaluator and hooks.

    Arguments:
        cfg (detectron2.config.CfgNode): Detectron2 config object.
        weights (Union[str, Dict]): Weights to load in the initialized model.
            If ``str``, then we assume path to a checkpoint, or if a ``dict``,
            we assume a state dict. This will be an ``str`` only if training
            is resumed from a Detectron2 checkpoint.
    """

    def __init__(self, cfg, weights):

        super().__init__(cfg)

        # Load pre-trained weights before wrapping to DDP because `ApexDDP` has
        # some weird issue with `DetectionCheckpointer`.
        # fmt: off
        if isinstance(weights, str):
            # weights are ``str`` means ImageNet init or resume training.
            self.start_iter = (
                DetectionCheckpointer(
                    self._trainer.model,
                    optimizer=self._trainer.optimizer,
                    scheduler=self.scheduler
                ).resume_or_load(weights, resume=True).get("iteration", -1) + 1
            )
        elif isinstance(weights, dict):
            # weights are a state dict means our pretrain init.
            DetectionCheckpointer(self._trainer.model)._load_model(weights)
        # fmt: on

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        r"""
        Build the evaluator matching the ``evaluator_type`` registered in the
        metadata of ``dataset_name``.

        Raises:
            NotImplementedError: If the evaluator type is not one of
                ``pascal_voc``, ``coco`` or ``lvis``.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")

        evaluator_type = d2.data.MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        elif evaluator_type == "coco":
            return COCOEvaluator(dataset_name, cfg, True, output_folder)
        elif evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)

        # Fail loudly instead of silently returning ``None``.
        # NOTE(review): ``DefaultTrainer.test`` is expected to catch
        # ``NotImplementedError`` from this hook -- confirm for the pinned
        # detectron2 version.
        raise NotImplementedError(
            f"No evaluator available for dataset '{dataset_name}' with "
            f"evaluator type '{evaluator_type}'."
        )

    def test(self, cfg=None, model=None, evaluators=None):
        r"""
        Evaluate the model and log results to stdout and tensorboard.

        Args:
            cfg: Detectron2 config; defaults to ``self.cfg``.
            model: Model to evaluate; defaults to ``self.model``.
            evaluators: Optional evaluators forwarded to the parent ``test``.

        Returns:
            The results returned by the parent ``test`` call.
        """
        cfg = cfg or self.cfg
        model = model or self.model

        # Forward ``evaluators`` (previously accepted but dropped).
        results = super().test(cfg, model, evaluators)
        flat_results = d2.evaluation.testing.flatten_results_dict(results)

        tensorboard_writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
        try:
            for k, v in flat_results.items():
                tensorboard_writer.add_scalar(k, v, self.start_iter)
        finally:
            # Flush pending events and release the event file handle.
            tensorboard_writer.close()

        # Hand results back to the caller instead of discarding them.
        return results
model = PretrainingModelFactory.from_config(_C) if _A.weight_init == "virtex": CheckpointManager(model=model).load(_A.checkpoint_path) else: model.visual.cnn.load_state_dict( torch.load(_A.checkpoint_path, map_location="cpu")["state_dict"], strict=False, ) weights = model.visual.detectron2_backbone_state_dict() else: # If random or imagenet init, just load weights after initializing model. model = PretrainingModelFactory.from_config(_C) weights = model.visual.detectron2_backbone_state_dict() # Back up pretrain config and model checkpoint (if provided). _C.dump(os.path.join(_A.serialization_dir, "pretrain_config.yaml")) if _A.weight_init == "virtex" and not _A.resume: torch.save( model.state_dict(), os.path.join(_A.serialization_dir, "pretrain_model.pth"), ) del model trainer = DownstreamTrainer(_D2C, weights) trainer.test() if _A.eval_only else trainer.train() if __name__ == "__main__": _A = parser.parse_args() # This will launch `main` and set appropriate CUDA device (GPU ID) as # per process (accessed in the beginning of `main`). 
def main(_A: argparse.Namespace):
    r"""
    Pretrain a model: build the train/val dataloaders, model, optimizer and LR
    scheduler from the config, optionally resume from a checkpoint, then run
    the training loop with periodic logging, checkpointing and validation.

    Args:
        _A: Parsed command-line arguments. This function reads
            ``num_gpus_per_machine``, ``config``, ``config_override``,
            ``cpu_workers``, ``resume_from``, ``serialization_dir``,
            ``log_every`` and ``checkpoint_every`` from it.
    """

    if _A.num_gpus_per_machine == 0:
        # Set device as CPU if num_gpus_per_machine = 0.
        device: Any = torch.device("cpu")
    else:
        # Get the current device as set for current distributed process.
        # Check `launch` function in `virtex.utils.distributed` module.
        device = torch.cuda.current_device()

    # Create a config object (this will be immutable) and perform common setup
    # such as logging and setting up serialization directory.
    _C = Config(_A.config, _A.config_override)
    common_setup(_C, _A)

    # -------------------------------------------------------------------------
    #   INSTANTIATE DATALOADER, MODEL, OPTIMIZER, SCHEDULER
    # -------------------------------------------------------------------------
    train_dataset = PretrainingDatasetFactory.from_config(_C, split="train")
    val_dataset = PretrainingDatasetFactory.from_config(_C, split="val")

    # Make `DistributedSampler`s to shard datasets across GPU processes.
    # Skip this if training on CPUs.
    train_sampler = (
        DistributedSampler(train_dataset, shuffle=True)  # type: ignore
        if _A.num_gpus_per_machine > 0
        else None
    )
    val_sampler = (
        DistributedSampler(val_dataset, shuffle=False)  # type: ignore
        if _A.num_gpus_per_machine > 0
        else None
    )
    # ``OPTIM.BATCH_SIZE`` is the total batch size; each process gets an equal
    # share. Shuffling is done by the dataloader only when there is no sampler.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
        sampler=train_sampler,
        shuffle=train_sampler is None,
        num_workers=_A.cpu_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=train_dataset.collate_fn,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=_C.OPTIM.BATCH_SIZE // dist.get_world_size(),
        sampler=val_sampler,
        shuffle=False,
        num_workers=_A.cpu_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=val_dataset.collate_fn,
    )

    model = PretrainingModelFactory.from_config(_C).to(device)
    optimizer = OptimizerFactory.from_config(_C, model.named_parameters())
    scheduler = LRSchedulerFactory.from_config(_C, optimizer)

    # -------------------------------------------------------------------------
    #   BEFORE TRAINING STARTS
    # -------------------------------------------------------------------------

    # Create a gradient scaler for automatic mixed precision.
    scaler = amp.GradScaler(enabled=_C.AMP)

    # Load checkpoint to resume training if specified.
    if _A.resume_from is not None:
        start_iteration = CheckpointManager(
            model=model, optimizer=optimizer, scheduler=scheduler, scaler=scaler,
        ).load(_A.resume_from)
    else:
        start_iteration = 0

    # Create an iterator from dataloader to sample batches perpetually.
    # NOTE(review): presumably `cycle` also moves batches to `device` and uses
    # `start_iteration` to resume the sampling order — confirm in
    # `virtex.utils.common`.
    train_dataloader_iter = cycle(train_dataloader, device, start_iteration)

    # Wrap model in DDP if using more than one processes.
    if dist.get_world_size() > 1:
        dist.synchronize()
        model = nn.parallel.DistributedDataParallel(model, device_ids=[device])

    # Keep track of time per iteration and ETA.
    timer = Timer(
        start_from=start_iteration + 1, total_iterations=_C.OPTIM.NUM_ITERATIONS
    )
    # Create tensorboard writer and checkpoint manager (only in master process).
    # These two names exist only in the master process; every later use is
    # guarded by `dist.is_master_process()`.
    if dist.is_master_process():
        tensorboard_writer = SummaryWriter(log_dir=_A.serialization_dir)
        tensorboard_writer.add_text("config", f"```\n{_C}\n```")

        checkpoint_manager = CheckpointManager(
            _A.serialization_dir,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
        )

    # -------------------------------------------------------------------------
    #   TRAINING LOOP
    # -------------------------------------------------------------------------
    for iteration in range(start_iteration + 1, _C.OPTIM.NUM_ITERATIONS + 1):
        timer.tic()
        optimizer.zero_grad()
        batch = next(train_dataloader_iter)

        # Forward pass runs under autocast (a no-op when AMP is disabled).
        with amp.autocast(enabled=_C.AMP):
            output_dict = model(batch)
            loss = output_dict["loss"]

        scaler.scale(loss).backward()

        # First clip norm of gradients, and then perform optimizer step.
        # Gradients are unscaled first so clipping sees their true magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), _C.OPTIM.CLIP_GRAD_NORM)
        scaler.step(optimizer)

        scaler.update()
        scheduler.step()
        timer.toc()

        # ---------------------------------------------------------------------
        #   LOGGING
        # ---------------------------------------------------------------------
        if iteration % _A.log_every == 0:
            logger.info(
                f"{timer.stats} [Loss {loss:.3f}] [GPU {dist.gpu_mem_usage()} MB]"
            )
            if dist.is_master_process():
                # First and last param groups are logged as the "visual" and
                # "common" learning rates respectively.
                tensorboard_writer.add_scalars(
                    "learning_rate",
                    {
                        "visual": optimizer.param_groups[0]["lr"],
                        "common": optimizer.param_groups[-1]["lr"],
                    },
                    iteration,
                )
                tensorboard_writer.add_scalars(
                    "train", output_dict["loss_components"], iteration
                )

        # ---------------------------------------------------------------------
        #   VALIDATION
        # ---------------------------------------------------------------------
        if iteration % _A.checkpoint_every == 0:
            if dist.is_master_process():
                checkpoint_manager.step(iteration)

            # All processes will wait till master process is done serializing.
            dist.synchronize()

            torch.set_grad_enabled(False)
            model.eval()

            # Accumulate different val loss components according to the type of
            # pretraining model.
            val_loss_counter: Counter = Counter()

            # `val_iteration` is 1-based, so after the loop it equals the
            # number of validation batches seen by this process.
            for val_iteration, val_batch in enumerate(val_dataloader, start=1):
                for key in val_batch:
                    val_batch[key] = val_batch[key].to(device)
                output_dict = model(val_batch)

                val_loss_counter.update(output_dict["loss_components"])

            # Divide each loss component by number of val batches per GPU.
            val_loss_dict = {
                k: v / val_iteration for k, v in dict(val_loss_counter).items()
            }
            dist.average_across_processes(val_loss_dict)
            # Restore training mode and gradient tracking for the next iteration.
            torch.set_grad_enabled(True)
            model.train()

            logger.info(f"Iteration: {iteration} [Val loss: {val_loss_dict}]")
            if dist.is_master_process():
                tensorboard_writer.add_scalars("val", val_loss_dict, iteration)
class Config:
    r"""
    This class provides package-wide configuration management. It is a
    nested dict-like structure with nested keys accessible as attributes. It
    contains sensible default values, which can be modified by (first) a YAML
    file and (second) a list of attributes and values.

    An instantiated object is immutable: modifying any attribute is illegal.
    You must override required parameter values either through ``config_file``
    or ``override_list`` arguments.

    Args:
        config_file: Path to a YAML file containing config parameters.
        override_list: A list of sequential attributes and values of parameters.
            This happens after overriding from YAML file. ``None`` (default)
            means there is nothing to override.

    Examples:
        Let a YAML file named "config.yaml" specify these parameters to override::

            OPTIM:
              BATCH_SIZE: 512
              LR: 0.01

        >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 1024])
        >>> _C.OPTIM.LR  # default: 0.001
        0.01
        >>> _C.OPTIM.BATCH_SIZE  # default: 256, file: 512
        1024
    """

    def __init__(
        self,
        config_file: Optional[str] = None,
        override_list: Optional[List[Any]] = None,
    ):
        _C = CN()

        # Random seed for NumPy and PyTorch, important for reproducibility.
        _C.RANDOM_SEED = 0
        # Train with Automatic Mixed Precision (native PyTorch).
        _C.AMP = True
        # Set CUDNN deterministic flag (torch.backends.cudnn.deterministic).
        # Setting this will ensure exact results on every run at the cost of
        # little slowdown. Good for debugging.
        _C.CUDNN_DETERMINISTIC = False
        # Set CUDNN benchmark flag (torch.backends.cudnn.benchmark). Enables
        # CUDNN to select fastest implementation for operations based on GPU.
        # May change results (in decimals) on different hardware, but faster
        # to train. Turn off while debugging.
        _C.CUDNN_BENCHMARK = True

        # ---------------------------------------------------------------------
        #   Data paths and parameters related to dataloading.
        # ---------------------------------------------------------------------
        _C.DATA = CN()

        # Path to the dataset root, structured as per README. Path is
        # assumed to be relative to project root.
        _C.DATA.ROOT = "datasets/coco"
        # Path to .model file generated by ``sentencepiece``.
        _C.DATA.TOKENIZER_MODEL = "datasets/vocab/coco_10k.model"

        # Handy config params for vocab size and indices of special tokens.
        # While these can be picked up from the tokenizer, having these in
        # the config makes it easy to create a model without instantiating too
        # many tokenizer instances (especially when not needed, e.g. model zoo).
        # These must match what's present in ``TOKENIZER_MODEL`` above.
        _C.DATA.VOCAB_SIZE = 10000
        # Index of out-of-vocabulary (and padding) token.
        _C.DATA.UNK_INDEX = 0
        # Index of the start-of-sentence [SOS] token.
        _C.DATA.SOS_INDEX = 1
        # Index of the end-of-sentence [EOS] token.
        _C.DATA.EOS_INDEX = 2
        # Index of the word masking token. While not used for captioning, having
        # this extra token makes it possible to train an MLM model without
        # re-creating a new vocab mapping.
        _C.DATA.MASK_INDEX = 3

        # Size of the image (square) to crop from original input image.
        _C.DATA.IMAGE_CROP_SIZE = 224
        # Maximum length of input caption (number of tokens).
        # Longer captions will be truncated up to this length.
        _C.DATA.MAX_CAPTION_LENGTH = 30

        # List of image transforms (pre-processing and data augmentation) to be
        # applied sequentially (always or randomly) during training and
        # validation. Refer ``virtex/factories.py`` for all possible transforms.
        _C.DATA.IMAGE_TRANSFORM_TRAIN = [
            "random_resized_crop",
            "horizontal_flip",
            "color_jitter",
            "normalize",
        ]
        _C.DATA.IMAGE_TRANSFORM_VAL = [
            "smallest_resize",
            "center_crop",
            "normalize",
        ]

        # Hyper-parameters for masked LM pretraining task. These are only used
        # when ``MODEL.NAME`` is "masked_lm".
        _C.DATA.MASKED_LM = CN()
        # Fraction of tokens to choose for masking, this must be less than 1.
        _C.DATA.MASKED_LM.MASK_PROPORTION = 0.15
        # Probability to replace chosen tokens with [MASK] token.
        _C.DATA.MASKED_LM.MASK_PROBABILITY = 0.85
        # Probability to replace chosen tokens with a random token.
        _C.DATA.MASKED_LM.REPLACE_PROBABILITY = 0.10

        # ---------------------------------------------------------------------
        #   Model architecture: visual backbone and textual head.
        # ---------------------------------------------------------------------
        _C.MODEL = CN()

        # Name of model, based on pretraining task.
        # Possible choices: {"token_classification", "multilabel_classification",
        # "captioning", "bicaptioning", "masked_lm", "virtex"}
        _C.MODEL.NAME = "virtex"

        _C.MODEL.VISUAL = CN()
        # Name of visual backbone. Possible choices: {"blind", "torchvision"}
        # Models from torchvision can be specified as shown below.
        _C.MODEL.VISUAL.NAME = "torchvision::resnet50"
        # Number of channels in pooled spatial features of visual backbone.
        _C.MODEL.VISUAL.FEATURE_SIZE = 2048
        # Whether to load ImageNet pretrained weights into visual backbone.
        _C.MODEL.VISUAL.PRETRAINED = False
        # Whether to keep visual backbone frozen and train only textual head.
        _C.MODEL.VISUAL.FROZEN = False

        _C.MODEL.TEXTUAL = CN()
        # Name of textual head. Set to "none" for MODEL.NAME = "*_classification".
        # Possible choices: {"transdec_postnorm", "transdec_prenorm"}.
        # Architectural hyper-parameters are specified as shown below.
        _C.MODEL.TEXTUAL.NAME = "transdec_postnorm::L1_H2048_A32_F8192"
        # L = Number of layers in the transformer.
        # H = Hidden size of the transformer (embeddings, attention features).
        # A = Number of attention heads in the transformer.
        # F = Size of feedforward layers in the transformer.
        # Typically, we have (A = H / 64) and (F = 4 * H).

        # Dropout probability for embedding, hidden features in textual head.
        _C.MODEL.TEXTUAL.DROPOUT = 0.1

        _C.MODEL.DECODER = CN()
        # What algorithm to use for decoding. Supported values: {"beam_search",
        # "nucleus_sampling"}.
        _C.MODEL.DECODER.NAME = "beam_search"
        # Number of beams to decode (1 = greedy decoding). Ignored when decoding
        # through nucleus sampling.
        _C.MODEL.DECODER.BEAM_SIZE = 5
        # Size of nucleus for sampling predictions. Ignored when decoding through
        # beam search.
        _C.MODEL.DECODER.NUCLEUS_SIZE = 0.9
        # Maximum length of decoded caption. Decoding may end earlier when [EOS]
        # token is sampled.
        _C.MODEL.DECODER.MAX_DECODING_STEPS = _C.DATA.MAX_CAPTION_LENGTH

        # ---------------------------------------------------------------------
        #   Optimization hyper-parameters, default values are for pretraining
        #   our best model on bicaptioning task (COCO Captions).
        # ---------------------------------------------------------------------
        _C.OPTIM = CN()

        # Name of optimizer to use. Supported values: {"sgd", "adamw"}.
        # AdamW uses default (beta1, beta2) values from PyTorch.
        _C.OPTIM.OPTIMIZER_NAME = "sgd"
        # Momentum co-efficient for SGD. Ignored for AdamW.
        _C.OPTIM.SGD_MOMENTUM = 0.9
        # Weight decay co-efficient for the optimizer.
        _C.OPTIM.WEIGHT_DECAY = 0.0001
        # Regex pattern of params for which there will be no weight decay.
        _C.OPTIM.NO_DECAY = ".*textual.(embedding|transformer).*(norm.*|bias)"
        # Max gradient norm for clipping to avoid exploding gradients.
        _C.OPTIM.CLIP_GRAD_NORM = 10.0

        # Wrap our optimizer with Lookahead (https://arxiv.org/abs/1907.08610).
        _C.OPTIM.LOOKAHEAD = CN()
        _C.OPTIM.LOOKAHEAD.USE = True
        _C.OPTIM.LOOKAHEAD.ALPHA = 0.5
        _C.OPTIM.LOOKAHEAD.STEPS = 5

        # We set different learning rates for CNN (visual backbone) and rest of
        # the model. CNN LR is typically much higher for training from scratch.
        # Both LRs undergo same warmup-decay schedules.

        # Total batch size (will be distributed evenly across GPUs).
        _C.OPTIM.BATCH_SIZE = 256
        # Max learning rate for CNN (visual backbone).
        _C.OPTIM.CNN_LR = 0.2
        # Max learning rate for rest of the model.
        _C.OPTIM.LR = 0.001
        # Number of iterations to train for, batches are randomly sampled.
        _C.OPTIM.NUM_ITERATIONS = 500000

        # Number of steps at the start of training for linear LR warmup.
        _C.OPTIM.WARMUP_STEPS = 10000
        # Learning rate annealing schedule for decay after warmup.
        # Possible choices: {"none", "linear", "cosine", "multistep"}.
        _C.OPTIM.LR_DECAY_NAME = "cosine"
        # Steps to decay LR for "multistep" schedule.
        _C.OPTIM.LR_STEPS = []
        # Factor to multiply with LR for "multistep" schedule.
        _C.OPTIM.LR_GAMMA = 0.1

        # Override parameter values from YAML file first, then from override
        # list.
        self._C = _C
        if config_file is not None:
            self._C.merge_from_file(config_file)
        # A ``None`` override list is treated as "no overrides"; a fresh list is
        # used so no default object is ever shared between instances.
        self._C.merge_from_list([] if override_list is None else override_list)

        # Make an instantiated object of this class immutable.
        self._C.freeze()

    def dump(self, file_path: str):
        r"""Save config at the specified file path.

        Args:
            file_path: Path to save config file (YAML).
        """
        # Use a context manager so the file is flushed and closed even if
        # serialization raises.
        with open(file_path, "w") as stream:
            self._C.dump(stream=stream)

    def __getattr__(self, attr: str):
        # ``__getattr__`` only runs when normal lookup fails. If ``_C`` itself is
        # missing (object created without running ``__init__``), reading
        # ``self._C`` below would re-enter this method endlessly, so fail fast.
        if attr == "_C":
            raise AttributeError(attr)
        return self._C.__getattr__(attr)

    def __str__(self):
        return self._C.__str__()

    def __repr__(self):
        return self._C.__repr__()
Extra tokens will be trimmed from the right end of the token list. """ def __init__( self, data_root: str, split: str, tokenizer: SentencePieceBPETokenizer, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, max_caption_length: int = 30, ): self._dset = CocoCaptionsDataset(data_root, split) self.tokenizer = tokenizer self.image_transform = image_transform self.max_caption_length = max_caption_length # Short handles for common tokens for convenience: self.padding_idx = tokenizer.token_to_id("") self.sos_id = tokenizer.token_to_id("[SOS]") self.eos_id = tokenizer.token_to_id("[EOS]") def __len__(self): return len(self._dset) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: # keys: {"image_id", "image", "captions"} instance = self._dset[idx] image_id, image, captions = ( instance["image_id"], instance["image"], instance["captions"], ) caption = random.choice(captions) # Transform image-caption pair and convert image from HWC to CHW format. # Pass in caption to image_transform due to paired horizontal flip. # Caption won't be tokenized/processed here. image_caption = self.image_transform(image=image, caption=caption) image, caption = image_caption["image"], image_caption["caption"] image = np.transpose(image, (2, 0, 1)) caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id] caption_tokens = caption_tokens[: self.max_caption_length] return { "image_id": torch.tensor(image_id, dtype=torch.long), "image": torch.tensor(image, dtype=torch.float), "caption_tokens": torch.tensor(caption_tokens, dtype=torch.long), "noitpac_tokens": torch.tensor(caption_tokens, dtype=torch.long).flip(0), "caption_lengths": torch.tensor(len(caption_tokens), dtype=torch.long), } def collate_fn( self, data: List[Dict[str, torch.Tensor]] ) -> Dict[str, torch.Tensor]: # Pad `caption_tokens` and `masked_labels` up to this length. 
class TokenClassificationDataset(Dataset):
    r"""
    A dataset which provides image-labelset pairs from a COCO Captions annotation
    file. The set of caption tokens (unordered) is treated as a labelset.

    Args:
        data_root: Path to dataset directory containing images and annotations.
        split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
        tokenizer: Tokenizer which maps word tokens to their integer IDs.
        image_transform: List of image transformations, from either
            albumentations or :mod:`virtex.data.transforms`.
        max_caption_length: Maximum number of tokens to keep in caption tokens.
            Extra tokens will be trimmed from the right end of the token list.
    """

    def __init__(
        self,
        data_root: str,
        split: str,
        tokenizer: SentencePieceBPETokenizer,
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
        max_caption_length: int = 30,
    ):
        self._dset = CocoCaptionsDataset(data_root, split)
        # Keep the tokenizer: ``__getitem__`` encodes the sampled caption with it.
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_caption_length = max_caption_length

        # Short handles for common tokens for convenience:
        # NOTE(review): the padding token string is empty here; presumably the
        # tokenizer resolves it to the unknown/padding index — confirm.
        self.padding_idx = tokenizer.token_to_id("")
        self.sos_id = tokenizer.token_to_id("[SOS]")
        self.eos_id = tokenizer.token_to_id("[EOS]")

    def __len__(self):
        return len(self._dset)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:

        # keys: {"image_id", "image", "captions"}
        instance = self._dset[idx]
        image_id, image, captions = (
            instance["image_id"],
            instance["image"],
            instance["captions"],
        )
        # One caption is sampled at random from the image's captions.
        caption = random.choice(captions)

        # Transform image-caption pair and convert image from HWC to CHW format.
        # Pass in caption to image_transform due to paired horizontal flip.
        # Caption won't be tokenized/processed here.
        image_caption = self.image_transform(image=image, caption=caption)
        image, caption = image_caption["image"], image_caption["caption"]
        image = np.transpose(image, (2, 0, 1))

        # Wrap with [SOS]/[EOS], then trim from the right to the maximum length.
        caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id]
        caption_tokens = caption_tokens[: self.max_caption_length]
        return {
            "image_id": torch.tensor(image_id, dtype=torch.long),
            "image": torch.tensor(image, dtype=torch.float),
            "labels": torch.tensor(caption_tokens, dtype=torch.long),
        }

    def collate_fn(
        self, data: List[Dict[str, torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        # Label lists have variable length; pad them to the longest in the batch.
        labels = torch.nn.utils.rnn.pad_sequence(
            [d["labels"] for d in data],
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return {
            "image_id": torch.stack([d["image_id"] for d in data], dim=0),
            "image": torch.stack([d["image"] for d in data], dim=0),
            "labels": labels,
        }
This is used for multilabel classification pretraining task. Args: data_root: Path to dataset directory containing images and annotations. split: Name of COCO 2017 split to read. One of ``{"train", "val"}``. image_transform: List of image transformations, from either `albumentations `_ or :mod:`virtex.data.transforms`. """ def __init__( self, data_root: str, split: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, ): self.image_transform = image_transform # Make a tuple of image id and its filename, get image_id from its # filename (assuming directory has images with names in COCO 2017 format). image_filenames = glob.glob(os.path.join(data_root, f"{split}2017", "*.jpg")) self.id_filename: List[Tuple[int, str]] = [ (int(os.path.basename(name)[:-4]), name) for name in image_filenames ] # Load the instance (bounding box and mask) annotations. _annotations = json.load( open(os.path.join(data_root, "annotations", f"instances_{split}2017.json")) ) # Make a mapping between COCO category id and its index, to make IDs # consecutive, else COCO has 80 classes with IDs 1-90. Start index from 1 # as 0 is reserved for background (padding idx). _category_ids = { ann["id"]: index + 1 for index, ann in enumerate(_annotations["categories"]) } # Mapping from image ID to list of unique category IDs (indices as above) # in corresponding image. self._labels: Dict[str, Any] = defaultdict(list) for ann in _annotations["annotations"]: self._labels[ann["image_id"]].append(_category_ids[ann["category_id"]]) # De-duplicate and drop empty labels, we only need to do classification. self._labels = { _id: list(set(lbl)) for _id, lbl in self._labels.items() if len(lbl) > 0 } # Filter out image IDs which didn't have any labels. self.id_filename = [ (t[0], t[1]) for t in self.id_filename if t[0] in self._labels ] # Padding while forming a batch, because images may have variable number # of instances. 
class CocoCaptionsDataset(Dataset):
    r"""
    A PyTorch dataset to read COCO Captions dataset and provide it completely
    unprocessed. This dataset is used by various task-specific datasets
    in :mod:`~virtex.data.datasets` module.

    Args:
        data_root: Path to the COCO dataset root directory.
        split: Name of COCO 2017 split to read. One of ``{"train", "val"}``.
    """

    def __init__(self, data_root: str, split: str):

        # Get paths to image directory and annotation file.
        image_dir = os.path.join(data_root, f"{split}2017")
        annotation_path = os.path.join(
            data_root, "annotations", f"captions_{split}2017.json"
        )
        # Close the annotation file as soon as it is parsed.
        with open(annotation_path, encoding="utf-8") as annotation_file:
            captions = json.load(annotation_file)

        # Collect list of captions for each image.
        captions_per_image: Dict[int, List[str]] = defaultdict(list)
        for ann in captions["annotations"]:
            # Perform common normalization: lowercase, NFKD normalization, then
            # strip accents by dropping the combining characters NFKD exposes.
            caption = ann["caption"].lower()
            caption = unicodedata.normalize("NFKD", caption)
            caption = "".join(
                [char for char in caption if not unicodedata.combining(char)]
            )
            captions_per_image[ann["image_id"]].append(caption)

        # Collect image file for each image (by its ID).
        image_filepaths: Dict[int, str] = {
            im["id"]: os.path.join(image_dir, im["file_name"])
            for im in captions["images"]
        }
        # Keep all annotations in memory. Make a list of tuples, each tuple
        # is ``(image_id, file_path, list[captions])``.
        self.instances = [
            (im_id, image_filepaths[im_id], image_captions)
            for im_id, image_captions in captions_per_image.items()
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_id, image_path, captions = self.instances[idx]

        # shape: (height, width, channels), dtype: uint8
        image = cv2.imread(image_path)
        # OpenCV loads images as BGR; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return {"image_id": image_id, "image": image, "captions": captions}
class INaturalist2018Dataset(Dataset):
    r"""
    A dataset which provides image-label pairs from the iNaturalist 2018 dataset.

    Args:
        data_root: Path to the iNaturalist 2018 dataset directory.
        split: Which split to read from. One of ``{"train", "val"}``.
        image_transform: List of image transformations, from either
            albumentations or :mod:`virtex.data.transforms`.
    """

    def __init__(
        self,
        data_root: str = "datasets/inaturalist",
        split: str = "train",
        image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM,
    ):
        self.split = split
        self.image_transform = image_transform

        # Read the annotation file and close it as soon as it is parsed.
        annotation_path = os.path.join(data_root, "annotations", f"{split}2018.json")
        with open(annotation_path) as annotation_file:
            annotations = json.load(annotation_file)

        # Make a mapping from image IDs to file paths.
        self.image_id_to_file_path = {
            ann["id"]: os.path.join(data_root, ann["file_name"])
            for ann in annotations["images"]
        }
        # Form a list of instances: (image_id, category_id) tuples.
        self.instances = [
            (ann["image_id"], ann["category_id"])
            for ann in annotations["annotations"]
        ]

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx: int):
        image_id, label = self.instances[idx]
        image_path = self.image_id_to_file_path[image_id]

        # Open image from path and apply transformation, convert to CHW format.
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        return {
            "image": torch.tensor(image, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_fn(data: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Stack per-example tensors along a new leading batch dimension.
        return {
            "image": torch.stack([d["image"] for d in data], dim=0),
            "label": torch.stack([d["label"] for d in data], dim=0),
        }
class ImageDirectoryDataset(Dataset):
    r"""
    A dataset which reads images from any directory. This class is useful to
    run image captioning inference on our models with any arbitrary images.

    Only regular files directly inside ``data_root`` are listed, in sorted
    order, so indices are stable across runs and platforms.

    Args:
        data_root: Path to a directory containing images.
        image_transform: List of image transformations, from either
            albumentations or :mod:`virtex.data.transforms`.

    Raises:
        ValueError: From ``__getitem__`` when a listed file cannot be decoded
            as an image.
    """

    def __init__(
        self, data_root: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM
    ):
        # ``glob`` returns entries in arbitrary order and also matches
        # sub-directories; sort and keep regular files only.
        self.image_paths = sorted(
            path
            for path in glob.glob(os.path.join(data_root, "*"))
            if os.path.isfile(path)
        )
        self.image_transform = image_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        image_path = self.image_paths[idx]
        # Remove extension from image name to use as image_id.
        image_id = os.path.splitext(os.path.basename(image_path))[0]

        # ``cv2.imread`` signals failure by returning None rather than
        # raising; report the offending path instead of letting ``cvtColor``
        # fail with an opaque error.
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image file: {image_path}")

        # Convert BGR to RGB, apply transformation, convert HWC to CHW format.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.image_transform(image=image)["image"]
        image = np.transpose(image, (2, 0, 1))

        # Return image id as string so collate_fn does not cast to torch.tensor.
        return {"image_id": str(image_id), "image": torch.tensor(image)}
""" def __init__( self, data_root: str, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM ): self.image_paths = glob.glob(os.path.join(data_root, "*")) self.image_transform = image_transform def __len__(self): return len(self.image_paths) def __getitem__(self, idx: int): image_path = self.image_paths[idx] # Remove extension from image name to use as image_id. image_id = os.path.splitext(os.path.basename(image_path))[0] # Open image from path and apply transformation, convert to CHW format. image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = self.image_transform(image=image)["image"] image = np.transpose(image, (2, 0, 1)) # Return image id as string so collate_fn does not cast to torch.tensor. return {"image_id": str(image_id), "image": torch.tensor(image)} ================================================ FILE: virtex/data/datasets/masked_lm.py ================================================ import math import random from typing import Callable, Dict, List import albumentations as alb import numpy as np import torch from torch.utils.data import Dataset from virtex.data.tokenizers import SentencePieceBPETokenizer from virtex.data import transforms as T from .coco_captions import CocoCaptionsDataset class MaskedLmDataset(Dataset): def __init__( self, data_root: str, split: str, tokenizer: SentencePieceBPETokenizer, image_transform: Callable = T.DEFAULT_IMAGE_TRANSFORM, max_caption_length: int = 30, mask_proportion: float = 0.15, mask_probability: float = 0.80, replace_probability: float = 0.10, ): self._dset = CocoCaptionsDataset(data_root, split) self.tokenizer = tokenizer self.image_transform = image_transform self.max_caption_length = max_caption_length # Short handles for common tokens for convenience: self.padding_idx = tokenizer.token_to_id("") self.sos_id = tokenizer.token_to_id("[SOS]") self.eos_id = tokenizer.token_to_id("[EOS]") self.mask_id = tokenizer.token_to_id("[MASK]") self._vocab_size = tokenizer.get_vocab_size() 
self._mask_proportion = mask_proportion self._mask_prob = mask_probability self._repl_prob = replace_probability def __len__(self): return len(self._dset) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: # keys: {"image_id", "image", "captions"} instance = self._dset[idx] image_id, image, captions = ( instance["image_id"], instance["image"], instance["captions"], ) caption = random.choice(captions) # Transform image-caption pair and convert image from HWC to CHW format. # Pass in caption to image_transform due to paired horizontal flip. # Caption won't be tokenized/processed here. image_caption = self.image_transform(image=image, caption=caption) image, caption = image_caption["image"], image_caption["caption"] image = np.transpose(image, (2, 0, 1)) caption_tokens = [self.sos_id, *self.tokenizer.encode(caption), self.eos_id] caption_tokens = caption_tokens[: self.max_caption_length] # --------------------------------------------------------------------- # Mask some tokens randomly. # --------------------------------------------------------------------- masked_labels = [self.padding_idx] * len(caption_tokens) # Indices in `caption_tokens` list to mask (minimum 1 index). # Leave out first and last indices (boundary tokens). tokens_to_mask: List[int] = random.sample( list(range(1, len(caption_tokens) - 1)), math.ceil((len(caption_tokens) - 2) * self._mask_proportion), ) for i in tokens_to_mask: # Whether to replace with [MASK] or random word. # If only one token, always [MASK]. 
class SentencePieceBPETokenizer:
    r"""
    A tokenizer based on SentencePiece with BPE sub-routine. It encodes
    caption strings into list of tokens.

    Args:
        model_path: Path to the ``.model`` file trained by SentencePiece.
    """

    SP_SPACE = u"▁"

    def __init__(self, model_path: str):
        self.model_path = model_path
        # Load pretrained tokenizer model.
        self.model = self._load_processor(model_path)

    @staticmethod
    def _load_processor(model_path: str):
        r"""Create a SentencePiece processor and load ``model_path`` into it."""
        processor = sp.SentencePieceProcessor()
        processor.Load(model_path)
        return processor

    def __getstate__(self):
        r"""
        Together with ``__setstate__``, this makes objects of this class
        picklable (and usable while data loading with multiple workers): the
        loaded processor is dropped from the pickled state.
        """
        return {**self.__dict__, "model": None}

    def __setstate__(self, state_dict: Dict[str, Any]):
        # Adopt the unpickled state, then reload the processor from its path.
        self.__dict__ = state_dict
        self.model = self._load_processor(self.model_path)

    def get_vocab_size(self) -> int:
        r"""Return number of tokens in vocabulary (including special tokens)."""
        return len(self.model)

    def token_to_id(self, token: str) -> int:
        r"""
        Get integer ID of a string token, as returned by the processor's
        ``piece_to_id`` lookup.
        """
        return self.model.piece_to_id(token)

    def id_to_token(self, token_id: int) -> str:
        r"""
        Get string token of an integer ID, as returned by the processor's
        ``id_to_piece`` lookup.
        """
        return self.model.id_to_piece(token_id)

    def encode(self, text: str) -> List[int]:
        r"""Convert a text string to a list of integer token ids."""
        return self.model.EncodeAsIds(text)

    def decode(self, token_ids: List[int]) -> str:
        r"""Convert a sequence of token IDs to a text string."""
        return self.model.DecodeIds(token_ids)
""" SP_SPACE = u"▁" def __init__(self, model_path: str): self.model_path = model_path # Load pretrained tokenizer model. self.model = sp.SentencePieceProcessor() self.model.Load(model_path) def __getstate__(self): r""" This magic method, along with ``__setstate__`` makes an object of this class picklable (and usable while data loading with multiple workers). """ state_dict = self.__dict__.copy() state_dict["model"] = None return state_dict def __setstate__(self, state_dict: Dict[str, Any]): self.__dict__ = state_dict self.model = sp.SentencePieceProcessor() self.model.Load(self.model_path) def get_vocab_size(self) -> int: r"""Return number of tokens in vocabulary (including special tokens).""" return len(self.model) def token_to_id(self, token: str) -> int: r"""Get integer ID of a string token (```` if does not exist).""" # Since tokenizer uses subword regularization, one token may break down to multiple IDs. # Keep trying till we get a single ID. return self.model.piece_to_id(token) def id_to_token(self, token_id: int) -> str: r"""Get string token of an integer ID (```` if does not exist).""" return self.model.id_to_piece(token_id) def encode(self, text: str) -> List[int]: r"""Convert a text string to a list of integer token ids.""" return self.model.EncodeAsIds(text) def decode(self, token_ids: List[int]) -> str: r"""Convert a sequence of token IDs to a text string.""" return self.model.DecodeIds(token_ids) ================================================ FILE: virtex/data/transforms.py ================================================ import albumentations as alb import cv2 class HorizontalFlip(alb.BasicTransform): r""" Flip the image horizontally randomly (equally likely) and replace the word "left" with "right" in the caption. .. note:: This transform can also work on images only (without the captions). Its behavior will be same as albumentations :class:`~albumentations.augmentations.transforms.HorizontalFlip`. 
class RandomResizedSquareCrop(alb.RandomResizedCrop):
    r"""
    A variant of :class:`albumentations.augmentations.transforms.RandomResizedCrop`
    which assumes a square crop (width = height). Everything else is same.

    Args:
        size: Dimension of the width and height of the cropped image.
    """

    def __init__(self, size: int, *args, **kwargs):
        # Forward one ``size`` as both output dimensions.
        super().__init__(*args, height=size, width=size, **kwargs)


class CenterSquareCrop(alb.CenterCrop):
    r"""
    A variant of :class:`albumentations.augmentations.transforms.CenterCrop`
    which assumes a square crop (width = height). Everything else is same.

    Args:
        size: Dimension of the width and height of the cropped image.
    """

    def __init__(self, size: int, *args, **kwargs):
        # Forward one ``size`` as both output dimensions.
        super().__init__(*args, height=size, width=size, **kwargs)


class SquareResize(alb.Resize):
    r"""
    A variant of :class:`albumentations.augmentations.transforms.Resize`
    which assumes a square resize (width = height). Everything else is same.

    Args:
        size: Dimension of the width and height of the resized image.
    """

    def __init__(self, size: int, *args, **kwargs):
        # Forward one ``size`` as both output dimensions.
        super().__init__(*args, height=size, width=size, **kwargs)
# ----------------------------------------------------------------------------- IMAGENET_COLOR_MEAN = (0.485, 0.456, 0.406) r"""ImageNet color normalization mean in RGB format (values in 0-1).""" IMAGENET_COLOR_STD = (0.229, 0.224, 0.225) r"""ImageNet color normalization std in RGB format (values in 0-1).""" DEFAULT_IMAGE_TRANSFORM = alb.Compose( [ alb.SmallestMaxSize(256, p=1.0), CenterSquareCrop(224, p=1.0), alb.Normalize(mean=IMAGENET_COLOR_MEAN, std=IMAGENET_COLOR_STD, p=1.0), ] ) r"""Default transform without any data augmentation (during pretraining).""" # ============================================================================= ================================================ FILE: virtex/factories.py ================================================ r""" This module is a collection of *factories* for creating objects of datasets, models, optimizers and other useful components. For example, a ResNet-50 visual backbone can be created as: .. code-block:: python >>> # Explicitly by name, args and kwargs: >>> backbone = VisualBackboneFactory.create( ... "torchvision::resnet50", pretrained=False ... ) >>> # Directly from a config object: >>> _C = Config(override_list=["MODEL.VISUAL.NAME", "torchvision::resnet50"]) >>> backbone = VisualBackboneFactory.from_config(_C) Creating directly from :class:`~virtex.config.Config` is fast and simple, and ensures minimal changes throughout the codebase upon any change in the call signature of underlying class; or config hierarchy. Refer description of specific factories for more details. 
""" import re from functools import partial from typing import Any, Callable, Dict, Iterable, List import albumentations as alb from torch import nn, optim import virtex.data as vdata import virtex.models as vmodels from virtex.config import Config from virtex.data import transforms as T from virtex.data.tokenizers import SentencePieceBPETokenizer from virtex.modules import visual_backbones, textual_heads from virtex.optim import Lookahead, lr_scheduler from virtex.utils.beam_search import AutoRegressiveBeamSearch from virtex.utils.nucleus_sampling import AutoRegressiveNucleusSampling class Factory: r""" Base class for all factories. All factories must inherit this base class and follow these guidelines for a consistent behavior: * Factory objects cannot be instantiated, doing ``factory = SomeFactory()`` is illegal. Child classes should not implement ``__init__`` methods. * All factories must have an attribute named ``PRODUCTS`` of type ``Dict[str, Callable]``, which associates each class with a unique string name which can be used to create it. * All factories must implement one classmethod, :meth:`from_config` which contains logic for creating an object directly by taking name and other arguments directly from :class:`~virtex.config.Config`. They can use :meth:`create` already implemented in this base class. * :meth:`from_config` should not use too many extra arguments than the config itself, unless necessary (such as model parameters for optimizer). """ PRODUCTS: Dict[str, Callable] = {} def __init__(self): raise ValueError( f"""Cannot instantiate {self.__class__.__name__} object, use `create` classmethod to create a product from this factory. 
""" ) @classmethod def create(cls, name: str, *args, **kwargs) -> Any: r"""Create an object by its name, args and kwargs.""" if name not in cls.PRODUCTS: raise KeyError(f"{cls.__class__.__name__} cannot create {name}.") return cls.PRODUCTS[name](*args, **kwargs) @classmethod def from_config(cls, config: Config) -> Any: r"""Create an object directly from config.""" raise NotImplementedError class TokenizerFactory(Factory): r""" Factory to create text tokenizers. This codebase ony supports one tokenizer for now, but having a dedicated factory makes it easy to add more if needed. Possible choices: ``{"SentencePieceBPETokenizer"}``. """ PRODUCTS: Dict[str, Callable] = { "SentencePieceBPETokenizer": SentencePieceBPETokenizer } @classmethod def from_config(cls, config: Config) -> SentencePieceBPETokenizer: r""" Create a tokenizer directly from config. Args: config: Config object with all the parameters. """ _C = config tokenizer = cls.create( "SentencePieceBPETokenizer", model_path=_C.DATA.TOKENIZER_MODEL, ) return tokenizer class ImageTransformsFactory(Factory): r""" Factory to create image transformations for common preprocessing and data augmentations. These are a mix of default transformations from `albumentations `_ and some extended ones defined in :mod:`virtex.data.transforms`. It uses sensible default values, however they can be provided with the name in dict syntax. Example: ``random_resized_crop::{'scale': (0.08, 1.0)}`` .. note:: This factory does not implement :meth:`from_config` method. It is only used by :class:`PretrainingDatasetFactory` and :class:`DownstreamDatasetFactory`. Possible choices: ``{"center_crop", "horizontal_flip", "random_resized_crop", "normalize", "global_resize", "color_jitter", "smallest_resize"}``. """ # fmt: off PRODUCTS: Dict[str, Callable] = { # Input resize transforms: whenever selected, these are always applied. # These transforms require one position argument: image dimension. 
"random_resized_crop": partial( T.RandomResizedSquareCrop, scale=(0.2, 1.0), ratio=(0.75, 1.333), p=1.0 ), "center_crop": partial(T.CenterSquareCrop, p=1.0), "smallest_resize": partial(alb.SmallestMaxSize, p=1.0), "global_resize": partial(T.SquareResize, p=1.0), # Keep hue limits small in color jitter because it changes color drastically # and captions often mention colors. Apply with higher probability. "color_jitter": partial( alb.ColorJitter, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8 ), "horizontal_flip": partial(T.HorizontalFlip, p=0.5), # Color normalization: whenever selected, always applied. This accepts images # in [0, 255], requires mean and std in [0, 1] and normalizes to `N(0, 1)`. "normalize": partial( alb.Normalize, mean=T.IMAGENET_COLOR_MEAN, std=T.IMAGENET_COLOR_STD, p=1.0 ), } # fmt: on @classmethod def create(cls, name: str, *args, **kwargs) -> Any: r"""Create an object by its name, args and kwargs.""" if "::" in name: name, __kwargs = name.split("::") _kwargs = eval(__kwargs) else: _kwargs = {} _kwargs.update(kwargs) return super().create(name, *args, **_kwargs) @classmethod def from_config(cls, config: Config): r"""Augmentations cannot be created from config, only :meth:`create`.""" raise NotImplementedError class PretrainingDatasetFactory(Factory): r""" Factory to create :class:`~torch.utils.data.Dataset` s for pretraining VirTex models. Datasets are created depending on pretraining task used. Typically these datasets either provide image-caption pairs, or only images from COCO Captions dataset (serialized to an LMDB file). As an exception, the dataset for ``multilabel_classification`` provides COCO images and labels of their bounding box annotations. Possible choices: ``{"bicaptioning", "captioning", "masked_lm", "token_classification", "multilabel_classification"}``. 
""" PRODUCTS: Dict[str, Callable] = { "virtex": vdata.CaptioningDataset, "bicaptioning": vdata.CaptioningDataset, "captioning": vdata.CaptioningDataset, "masked_lm": vdata.MaskedLmDataset, "token_classification": vdata.TokenClassificationDataset, "multilabel_classification": vdata.MultiLabelClassificationDataset, } @classmethod def from_config(cls, config: Config, split: str = "train"): r""" Create a dataset directly from config. Names in this factory match with names in :class:`PretrainingModelFactory` because both use same config parameter ``MODEL.NAME`` to create objects. Args: config: Config object with all the parameters. split: Which dataset split to load. One of ``{"train", "val"}``. """ _C = config # Every dataset needs these two args. kwargs = {"data_root": _C.DATA.ROOT, "split": split} # Create a list of image transformations based on transform names. image_transform_list: List[Callable] = [] for name in getattr(_C.DATA, f"IMAGE_TRANSFORM_{split.upper()}"): # Pass dimensions if cropping / resizing, else rely on the defaults # as per `ImageTransformsFactory`. if "resize" in name or "crop" in name: image_transform_list.append( ImageTransformsFactory.create(name, _C.DATA.IMAGE_CROP_SIZE) ) else: image_transform_list.append(ImageTransformsFactory.create(name)) kwargs["image_transform"] = alb.Compose(image_transform_list) # Add dataset specific kwargs. if _C.MODEL.NAME != "multilabel_classification": tokenizer = TokenizerFactory.from_config(_C) kwargs.update( tokenizer=tokenizer, max_caption_length=_C.DATA.MAX_CAPTION_LENGTH, ) if _C.MODEL.NAME == "masked_lm": kwargs.update( mask_proportion=_C.DATA.MASKED_LM.MASK_PROPORTION, mask_probability=_C.DATA.MASKED_LM.MASK_PROBABILITY, replace_probability=_C.DATA.MASKED_LM.REPLACE_PROBABILITY, ) # Dataset names match with model names (and ofcourse pretext names). 
class DownstreamDatasetFactory(Factory):
    r"""
    Factory to create :class:`~torch.utils.data.Dataset` s for evaluating
    VirTex models on downstream tasks.

    Possible choices: ``{"datasets/VOC2007", "datasets/imagenet",
    "datasets/inaturalist"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "datasets/VOC2007": vdata.VOC07ClassificationDataset,
        "datasets/imagenet": vdata.ImageNetDataset,
        "datasets/inaturalist": vdata.INaturalist2018Dataset,
    }

    @classmethod
    def from_config(cls, config: Config, split: str = "train"):
        r"""
        Create a dataset directly from config. Names in this factory are paths
        of dataset directories (relative to the project directory), because
        config parameter ``DATA.ROOT`` is used to create objects.

        Args:
            config: Config object with all the parameters.
            split: Which dataset split to load. One of ``{"trainval", "test"}``
                for VOC2007, or one of ``{"train", "val"}`` for ImageNet and
                iNaturalist.
        """
        _C = config

        # Any split containing "train" (so also VOC2007 "trainval") uses
        # `IMAGE_TRANSFORM_TRAIN`; every other split (e.g. "test") uses
        # `IMAGE_TRANSFORM_VAL`.
        if "train" in split:
            image_transform_names: List[str] = list(_C.DATA.IMAGE_TRANSFORM_TRAIN)
        else:
            image_transform_names = list(_C.DATA.IMAGE_TRANSFORM_VAL)

        # Fixed image dimension passed to each resize/crop transform; other
        # transforms rely on the defaults.
        dimension_by_transform = {
            "random_resized_crop": 224,
            "center_crop": 224,
            "global_resize": 224,
            "smallest_resize": 256,
        }

        image_transform_list: List[Callable] = []
        for name in image_transform_names:
            base_name = name.split("::")[0]
            if base_name in dimension_by_transform:
                transform = ImageTransformsFactory.create(
                    name, dimension_by_transform[base_name]
                )
            else:
                transform = ImageTransformsFactory.create(name)
            image_transform_list.append(transform)

        # Every dataset needs data root and split, plus the composed transform.
        kwargs = {
            "data_root": _C.DATA.ROOT,
            "split": split,
            "image_transform": alb.Compose(image_transform_list),
        }
        return cls.create(_C.DATA.ROOT, **kwargs)
class VisualBackboneFactory(Factory):
    r"""
    Factory to create :mod:`~virtex.modules.visual_backbones`. This factory
    supports any ResNet-like model from Torchvision. Use the method name for
    model as in torchvision, for example, ``torchvision::resnet50``,
    ``torchvision::wide_resnet50_2`` etc.

    Possible choices: ``{"torchvision"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "torchvision": visual_backbones.TorchvisionVisualBackbone,
    }

    @classmethod
    def from_config(cls, config: Config) -> visual_backbones.VisualBackbone:
        r"""
        Create a visual backbone directly from config.

        Args:
            config: Config object with all the parameters.
        """
        visual_config = config.MODEL.VISUAL
        kwargs = {"visual_feature_size": visual_config.FEATURE_SIZE}

        if "torchvision" not in visual_config.NAME:
            return cls.create(visual_config.NAME, **kwargs)

        # Torchvision models are named "torchvision::<cnn_name>"; the part
        # after the last "::" selects the architecture.
        cnn_name = visual_config.NAME.split("::")[-1]
        kwargs["pretrained"] = visual_config.PRETRAINED
        kwargs["frozen"] = visual_config.FROZEN
        return cls.create("torchvision", cnn_name, **kwargs)
""" _C = config name = _C.MODEL.TEXTUAL.NAME kwargs = { "visual_feature_size": _C.MODEL.VISUAL.FEATURE_SIZE, "vocab_size": _C.DATA.VOCAB_SIZE, } if "trans" in _C.MODEL.TEXTUAL.NAME: # Get architectural hyper-params as per name by matching regex. name, architecture = name.split("::") architecture = re.match(r"L(\d+)_H(\d+)_A(\d+)_F(\d+)", architecture) num_layers = int(architecture.group(1)) hidden_size = int(architecture.group(2)) attention_heads = int(architecture.group(3)) feedforward_size = int(architecture.group(4)) # Mask the future tokens for autoregressive captioning. mask_future = _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"} kwargs.update( hidden_size=hidden_size, num_layers=num_layers, attention_heads=attention_heads, feedforward_size=feedforward_size, dropout=_C.MODEL.TEXTUAL.DROPOUT, mask_future_positions=mask_future, max_caption_length=_C.DATA.MAX_CAPTION_LENGTH, padding_idx=_C.DATA.UNK_INDEX, ) return cls.create(name, **kwargs) class PretrainingModelFactory(Factory): r""" Factory to create :mod:`~virtex.models` for different pretraining tasks. Possible choices: ``{"bicaptioning", "captioning", "masked_lm", "token_classification", "multilabel_classification"}``. """ PRODUCTS: Dict[str, Callable] = { # First two are basically the same. Added for shorthand notation. "virtex": vmodels.VirTexModel, "bicaptioning": vmodels.BidirectionalCaptioningModel, "captioning": vmodels.ForwardCaptioningModel, "masked_lm": vmodels.MaskedLMModel, "token_classification": vmodels.TokenClassificationModel, "multilabel_classification": vmodels.MultiLabelClassificationModel, } @classmethod def from_config(cls, config: Config) -> nn.Module: r""" Create a model directly from config. Args: config: Config object with all the parameters. """ _C = config # Build visual and textual streams based on config. visual = VisualBackboneFactory.from_config(_C) textual = TextualHeadFactory.from_config(_C) # Add model specific kwargs. 
class PretrainingModelFactory(Factory):
    r"""
    Factory to create :mod:`~virtex.models` for different pretraining tasks.

    Possible choices: ``{"virtex", "bicaptioning", "captioning", "masked_lm",
    "token_classification", "multilabel_classification"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        # First two are basically the same. Added for shorthand notation.
        "virtex": vmodels.VirTexModel,
        "bicaptioning": vmodels.BidirectionalCaptioningModel,
        "captioning": vmodels.ForwardCaptioningModel,
        "masked_lm": vmodels.MaskedLMModel,
        "token_classification": vmodels.TokenClassificationModel,
        "multilabel_classification": vmodels.MultiLabelClassificationModel,
    }

    @classmethod
    def from_config(cls, config: Config) -> nn.Module:
        r"""
        Create a model directly from config.

        Args:
            config: Config object with all the parameters.
        """
        _C = config
        model_name = _C.MODEL.NAME

        # Build visual and textual streams based on config.
        visual = VisualBackboneFactory.from_config(_C)
        textual = TextualHeadFactory.from_config(_C)

        # Model specific kwargs; they must match the call signature of the
        # model being created. Models not listed here take none.
        kwargs = {}
        if model_name in {"virtex", "captioning", "bicaptioning"}:
            kwargs["sos_index"] = _C.DATA.SOS_INDEX
            kwargs["eos_index"] = _C.DATA.EOS_INDEX
            kwargs["decoder"] = CaptionDecoderFactory.from_config(_C)
        elif model_name == "token_classification":
            kwargs["ignore_indices"] = [
                _C.DATA.UNK_INDEX,
                _C.DATA.SOS_INDEX,
                _C.DATA.EOS_INDEX,
                _C.DATA.MASK_INDEX,
            ]
        elif model_name == "multilabel_classification":
            kwargs["ignore_indices"] = [0]  # background index

        return cls.create(model_name, visual, textual, **kwargs)


class CaptionDecoderFactory(Factory):
    r"""
    Factory to create decoders for predicting captions from VirTex model.

    Possible choices: ``{"beam_search", "nucleus_sampling"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "beam_search": AutoRegressiveBeamSearch,
        "nucleus_sampling": AutoRegressiveNucleusSampling,
    }

    @classmethod
    def from_config(cls, config: Config) -> nn.Module:
        r"""
        Create a caption decoder directly from config.

        Args:
            config: Config object with all the parameters.
        """
        decoder_config = config.MODEL.DECODER
        decoder_name = decoder_config.NAME

        kwargs = {
            "eos_index": config.DATA.EOS_INDEX,
            "max_steps": decoder_config.MAX_DECODING_STEPS,
        }
        # Each decoder takes exactly one extra, decoder specific size.
        if decoder_name == "beam_search":
            kwargs["beam_size"] = decoder_config.BEAM_SIZE
        elif decoder_name == "nucleus_sampling":
            kwargs["nucleus_size"] = decoder_config.NUCLEUS_SIZE

        return cls.create(decoder_name, **kwargs)
class OptimizerFactory(Factory):
    r"""Factory to create optimizers. Possible choices: ``{"sgd", "adamw"}``."""

    PRODUCTS: Dict[str, Callable] = {"sgd": optim.SGD, "adamw": optim.AdamW}

    @classmethod
    def from_config(
        cls, config: Config, named_parameters: Iterable[Any]
    ) -> optim.Optimizer:
        r"""
        Create an optimizer directly from config.

        Args:
            config: Config object with all the parameters.
            named_parameters: Named parameters of model (retrieved by
                ``model.named_parameters()``) for the optimizer. We use named
                parameters to set different LR and turn off weight decay for
                certain parameters based on their names.
        """
        optim_config = config.OPTIM

        def _make_group(param_name: str, param: Any) -> Dict[str, Any]:
            # Parameters whose name matches ``OPTIM.NO_DECAY`` get no weight
            # decay; parameters with "cnn" in their name use ``OPTIM.CNN_LR``.
            if re.match(optim_config.NO_DECAY, param_name):
                weight_decay = 0.0
            else:
                weight_decay = optim_config.WEIGHT_DECAY
            if "cnn" in param_name:
                learning_rate = optim_config.CNN_LR
            else:
                learning_rate = optim_config.LR
            return {"params": [param], "lr": learning_rate, "weight_decay": weight_decay}

        # One param group per parameter, so each can carry its own settings.
        param_groups = [_make_group(name, param) for name, param in named_parameters]

        # Only SGD takes a momentum argument here.
        extra_kwargs = {}
        if optim_config.OPTIMIZER_NAME == "sgd":
            extra_kwargs["momentum"] = optim_config.SGD_MOMENTUM

        optimizer = cls.create(optim_config.OPTIMIZER_NAME, param_groups, **extra_kwargs)
        if not optim_config.LOOKAHEAD.USE:
            return optimizer

        return Lookahead(
            optimizer,
            k=optim_config.LOOKAHEAD.STEPS,
            alpha=optim_config.LOOKAHEAD.ALPHA,
        )
class LRSchedulerFactory(Factory):
    r"""
    Factory to create LR schedulers. All schedulers have a built-in LR warmup
    schedule before actual LR scheduling (decay) starts.

    Possible choices: ``{"none", "multistep", "linear", "cosine"}``.
    """

    PRODUCTS: Dict[str, Callable] = {
        "none": lr_scheduler.LinearWarmupNoDecayLR,
        "multistep": lr_scheduler.LinearWarmupMultiStepLR,
        "linear": lr_scheduler.LinearWarmupLinearDecayLR,
        "cosine": lr_scheduler.LinearWarmupCosineAnnealingLR,
    }

    @classmethod
    def from_config(
        cls, config: Config, optimizer: optim.Optimizer
    ) -> optim.lr_scheduler.LambdaLR:
        r"""
        Create an LR scheduler directly from config.

        Args:
            config: Config object with all the parameters.
            optimizer: Optimizer on which LR scheduling would be performed.
        """
        optim_config = config.OPTIM
        schedule_name = optim_config.LR_DECAY_NAME

        kwargs = {
            "total_steps": optim_config.NUM_ITERATIONS,
            "warmup_steps": optim_config.WARMUP_STEPS,
        }
        # Multistep LR requires multiplicative factor and milestones.
        if schedule_name == "multistep":
            kwargs["gamma"] = optim_config.LR_GAMMA
            kwargs["milestones"] = optim_config.LR_STEPS

        return cls.create(schedule_name, optimizer, **kwargs)
class _ModelZooUrls:
    r"""Mapping from config names to URL suffixes of pretrained weights."""

    # Common prefix of every checkpoint URL.
    URL_PREFIX = "https://www.dropbox.com/s"

    # Keys are config paths relative to ``configs/``; values are the ID placed
    # between the prefix and the file name in the URL. The same config listed
    # under several ablation directories maps to the same ID.
    CONFIG_PATH_TO_DB_ID = {

        # Pretraining Task Ablations
        "task_ablations/bicaptioning_R_50_L1_H2048.yaml": "mbeeso8wyieq8wy",
        "task_ablations/captioning_R_50_L1_H2048.yaml": "r6zen9k43m5oo58",
        "task_ablations/token_classification_R_50.yaml": "o4p9lki505r0mef",
        "task_ablations/multilabel_classification_R_50.yaml": "hbspp3jv3u8h3bc",
        "task_ablations/masked_lm_R_50_L1_H2048.yaml": "ldzrk6vem4mg6bl",

        # Width Ablations
        "width_ablations/bicaptioning_R_50_L1_H512.yaml": "o9fr69jjqfn8a65",
        "width_ablations/bicaptioning_R_50_L1_H768.yaml": "1zxglqrrbfufv9d",
        "width_ablations/bicaptioning_R_50_L1_H1024.yaml": "pdat4tvhnqxel64",
        "width_ablations/bicaptioning_R_50_L1_H2048.yaml": "mbeeso8wyieq8wy",

        # Depth Ablations
        "depth_ablations/bicaptioning_R_50_L1_H1024.yaml": "pdat4tvhnqxel64",
        "depth_ablations/bicaptioning_R_50_L2_H1024.yaml": "ft1vtt4okirzjgo",
        "depth_ablations/bicaptioning_R_50_L3_H1024.yaml": "5ldo1rcsnrshmjr",
        "depth_ablations/bicaptioning_R_50_L4_H1024.yaml": "zgiit2wcluuq3xh",

        # Backbone Ablations
        "backbone_ablations/bicaptioning_R_50_L1_H1024.yaml": "pdat4tvhnqxel64",
        "backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml": "5o198ux709r6376",
        "backbone_ablations/bicaptioning_R_101_L1_H1024.yaml": "bb74jubt68cpn80",
    }
``width_ablations/bicaptioning_R_50_L1_H2048.yaml``) pretrained: If ``True``, will initialize the model with the pretrained weights. If ``False``, the weights will be initialized randomly. """ # Get the original path to config file (shipped with inside the package). _pkg_config_path = pkg_resources.resource_filename( "virtex.model_zoo", os.path.join("configs", config_path) ) if not os.path.exists(_pkg_config_path): raise RuntimeError("{} not available in Model Zoo!".format(config_path)) _C = Config(_pkg_config_path) model = PretrainingModelFactory.from_config(_C) if pretrained: # Get URL for the checkpoint for this config path. if config_path in _ModelZooUrls.CONFIG_PATH_TO_DB_ID: dropbox_id = _ModelZooUrls.CONFIG_PATH_TO_DB_ID[config_path] filename = os.path.basename(config_path).replace(".yaml", ".pth") checkpoint_url = f"{_ModelZooUrls.URL_PREFIX}/{dropbox_id}/{filename}?dl=1" else: raise RuntimeError("{} not available in Model Zoo!".format(config_path)) # Download the pretrained model weights and save with a sensible name. # This will be downloaded only if it does not exist. 
class CaptioningModel(nn.Module):
    r"""
    A model to perform image captioning (either in both forward and backward
    directions independently, or only in forward direction). It is composed of
    a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    During training, it maximizes the likelihood of ground truth caption
    conditioned on image features. During inference, it predicts a caption for
    an input image through the supplied ``decoder``.

    Args:
        visual: A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
            computes visual features from an input image.
        textual: A :class:`~virtex.modules.textual_heads.TextualHead` which
            makes final predictions conditioned on visual features.
        sos_index: The index of the start token (``[SOS]``) in vocabulary.
        eos_index: The index of the end token (``[EOS]``) in vocabulary.
        caption_backward: Whether to *also* perform captioning in backward
            direction. Default is ``False`` -- only forward captioning is
            performed. When ``True``, a clone of textual head is created, which
            does not share weights with "forward" model except visual
            projection and input/output embeddings.
        decoder: A :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
            or :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`
            object for decoding captions during inference (unused during
            training).
    """

    def __init__(
        self,
        visual: VisualBackbone,
        textual: TextualHead,
        caption_backward: bool = False,
        sos_index: int = 1,
        eos_index: int = 2,
        decoder: Any = None,
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.padding_idx = self.textual.padding_idx
        self.caption_backward = caption_backward

        # Clone the textual module for backward direction if doing captioning
        # in both directions (separately).
        if self.caption_backward:
            self.backward_textual = copy.deepcopy(self.textual)

            # Share weights for visual projection, and input/output embeddings.
            # Re-assigning after the deep copy makes both heads hold the very
            # same sub-module objects.
            self.backward_textual.visual_projection = self.textual.visual_projection
            self.backward_textual.embedding = self.textual.embedding
            self.backward_textual.output = self.textual.output

        # These boundary indices are needed for beam search.
        self.sos_index = sos_index
        self.eos_index = eos_index
        self.decoder = decoder
        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        r"""
        Given a batch of images and captions, compute log likelihood loss per
        caption token during training. During inference (with images only),
        predict a caption through either beam search decoding or nucleus
        sampling.

        Args:
            batch: A batch of images and (optionally) ground truth caption
                tokens. Possible set of keys: ``{"image_id", "image",
                "caption_tokens", "noitpac_tokens", "caption_lengths"}``.

        Returns:
            A dict with the following structure, containing loss for
            optimization, loss components to log directly to tensorboard, and
            optionally predictions. When ``"caption_tokens"`` is absent from
            the batch, the dict only contains ``"predictions"``.

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "captioning_forward": torch.Tensor,
                        "captioning_backward": torch.Tensor, (optional)
                    },
                    "predictions": torch.Tensor
                }

        Raises:
            ValueError: If captions must be decoded (no ``"caption_tokens"`` in
                the batch) but no ``decoder`` was provided.
        """

        # shape: (batch_size, channels, height, width)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        if "caption_tokens" in batch:
            caption_tokens = batch["caption_tokens"]
            caption_lengths = batch["caption_lengths"]

            # shape: (batch_size, max_caption_length, vocab_size)
            output_logits = self.textual(
                visual_features, caption_tokens, caption_lengths
            )
            # Logits at timestep ``t`` are scored against the token at ``t + 1``,
            # hence the last logit and the first token are dropped.
            loss = self.loss(
                output_logits[:, :-1].contiguous().view(-1, self.textual.vocab_size),
                caption_tokens[:, 1:].contiguous().view(-1),
            )
            output_dict: Dict[str, Any] = {
                "loss": loss,
                # Single scalar per batch for logging in training script.
                "loss_components": {"captioning_forward": loss.clone().detach()},
            }
            # Do captioning in backward direction if specified.
            if self.caption_backward:
                backward_caption_tokens = batch["noitpac_tokens"]

                backward_output_logits = self.backward_textual(
                    visual_features, backward_caption_tokens, caption_lengths
                )
                backward_loss = self.loss(
                    backward_output_logits[:, :-1]
                    .contiguous()
                    .view(-1, self.textual.vocab_size),
                    backward_caption_tokens[:, 1:].contiguous().view(-1),
                )
                output_dict["loss"] += backward_loss

                # Single scalar per batch for logging in training script.
                output_dict["loss_components"].update(
                    captioning_backward=backward_loss.clone().detach()
                )

            if not self.training:
                # During validation (while pretraining), get best prediction
                # at every timestep.
                output_dict["predictions"] = torch.argmax(output_logits, dim=-1)
        else:
            if self.decoder is None:
                raise ValueError("Decoder for predicting captions is missing!")

            # During inference, get beam search predictions for forward
            # model. Predictions from forward transformer will be shifted
            # right by one timestep.
            start_predictions = visual_features.new_full(
                (batch_size,), self.sos_index
            ).long()
            # Add image features as a default argument to match callable
            # signature accepted by beam search class (partial captions only).
            decoding_step = functools.partial(self.decoding_step, visual_features)

            predicted_caption, _ = self.decoder.search(
                start_predictions, decoding_step
            )
            output_dict = {"predictions": predicted_caption}

        return output_dict

    def decoding_step(
        self, visual_features: torch.Tensor, partial_captions: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Given visual features and a batch of (assumed) partial captions, predict
        the logits over output vocabulary tokens for next timestep. This method
        is used by :class:`~virtex.utils.beam_search.AutoRegressiveBeamSearch`
        and :class:`~virtex.utils.nucleus_sampling.AutoRegressiveNucleusSampling`.

        .. note::

            For nucleus sampling, ``beam_size`` will always be 1 (not relevant).

        Args:
            visual_features: A tensor of shape ``(batch_size, channels, height,
                width)`` containing features from the visual backbone. It is
                repeated ``beam_size`` times along the batch dimension here.
            partial_captions: A tensor of shape ``(batch_size * beam_size,
                timesteps)`` containing tokens predicted so far -- one for each
                beam. We need all prior predictions because our model is
                auto-regressive. A 1-D tensor is treated as a single timestep.

        Returns:
            A tensor of shape ``(batch_size * beam_size, vocab_size)`` -- logits
            over output vocabulary tokens for next timestep.
        """

        # Expand and repeat image features while doing beam search.
        batch_size, channels, height, width = visual_features.size()
        beam_size = partial_captions.size(0) // batch_size
        if beam_size > 1:
            # shape: (batch_size * beam_size, channels, height, width)
            visual_features = visual_features.unsqueeze(1).repeat(1, beam_size, 1, 1, 1)
            visual_features = visual_features.view(
                batch_size * beam_size, channels, height, width
            )

        # Provide caption lengths as current length (irrespective of predicted
        # EOS/padding tokens). shape: (batch_size, )
        caption_lengths = torch.ones_like(partial_captions)
        if len(caption_lengths.size()) == 2:
            caption_lengths = caption_lengths.sum(1)
        else:
            # Add a timestep. shape: (batch_size, 1)
            partial_captions = partial_captions.unsqueeze(1)

        # shape: (batch_size * beam_size, partial_caption_length, vocab_size)
        logits = self.textual(visual_features, partial_captions, caption_lengths)
        # Return logits from the last timestep.
        return logits[:, -1, :]

    def log_predictions(
        self, batch: Dict[str, torch.Tensor], tokenizer: SentencePieceBPETokenizer
    ) -> str:
        r"""
        Run the model in evaluation mode on ``batch`` and format ground truth
        captions next to predicted captions as human-readable text.

        The module is switched to ``eval()`` for the forward pass and back to
        ``train()`` afterwards.

        Args:
            batch: A batch containing at least ``"image"``, ``"caption_tokens"``
                and ``"caption_lengths"`` (see :meth:`forward`).
            tokenizer: Tokenizer used to turn token indices back into text.

        Returns:
            A multi-line string with one ground truth / prediction pair per
            instance in the batch.
        """

        self.eval()
        with torch.no_grad():
            predictions = self.forward(batch)["predictions"]
        self.train()

        predictions_str = ""
        for tokens, preds in zip(batch["caption_tokens"], predictions):
            # Token indices are integers: they must be decoded by the tokenizer,
            # ``str.join`` on them would raise ``TypeError``.
            predictions_str += f"""
                Caption tokens : {tokenizer.decode(tokens.tolist())}
                Predictions (f): {tokenizer.decode(preds.tolist())}

                """
        return predictions_str
class ClassificationModel(nn.Module):
    r"""
    A model to perform classification (generally, with multiple targets). It is
    composed of a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a
    :class:`~virtex.modules.textual_heads.TextualHead` on top of it.

    .. note::

        As with currently available textual heads, only one textual head is
        supported here: :class:`~virtex.modules.textual_heads.LinearTextualHead`.

    During training, it minimizes the KL-divergence loss with a K-hot vector,
    with values ``1/K``, where K are the number of unique labels to classify.

    Args:
        visual: A :class:`~virtex.modules.visual_backbones.VisualBackbone` which
            computes visual features from an input image.
        textual: A :class:`~virtex.modules.textual_heads.TextualHead` which
            makes final predictions conditioned on visual features.
        ignore_indices: Ignore a set of token indices while computing
            KL-divergence loss. These are special tokens such as ``[SOS]``,
            ``[EOS]`` etc.
    """

    def __init__(
        self, visual: VisualBackbone, textual: TextualHead, ignore_indices: List[int]
    ):
        super().__init__()
        self.visual = visual
        self.textual = textual
        self.ignore_indices = ignore_indices

    def forward(self, batch: Dict[str, torch.Tensor]):
        r"""
        Given a batch of images and set of labels, perform classification with
        multiple targets by minimizing a KL-divergence loss.

        Args:
            batch: A batch of images and labels. Possible set of keys:
                ``{"image_id", "image", "labels"}``

        Returns:
            A dict with the following structure, containing loss for
            optimization, loss components to log directly to tensorboard, and
            optionally predictions (only when the module is not in training
            mode).

            .. code-block::

                {
                    "loss": torch.Tensor,
                    "loss_components": {
                        "classification": torch.Tensor,
                    },
                    "predictions": torch.Tensor
                }
        """

        # shape: (batch_size, visual_feature_size, ...)
        visual_features = self.visual(batch["image"])
        batch_size = visual_features.size(0)

        # shape: (batch_size, vocab_size)
        logits = self.textual(visual_features)
        logprobs = F.log_softmax(logits, dim=1)

        # The target of every instance is a K-hot vector over its unique
        # labels, so the per-instance loss is the negated mean log-probability
        # of those labels. Instances have a varying number of labels, hence the
        # explicit Python loop.
        loss = torch.tensor(0.0, device=logprobs.device)

        for index in range(batch_size):
            # Unique labels of this instance, minus special tokens such as
            # [SOS], [EOS] and any other index asked to be ignored.
            kept_labels = [
                label
                for label in batch["labels"][index].unique()
                if label not in self.ignore_indices
            ]
            instance_logprobs = logprobs[index, kept_labels].mean()

            # Accumulate negative log-probability for this instance.
            loss = loss - instance_logprobs

        output_dict: Dict[str, Any] = {
            # Average loss across instances.
            "loss": loss / batch_size,
        }
        # Single scalar per batch for logging to tensorboard in training script.
        output_dict["loss_components"] = {
            "classification": loss.clone().detach() / batch_size
        }

        # Outside of training, also report the ten most probable tokens per
        # instance; this is handy for logging.
        if not self.training:
            _, top_tokens = logprobs.topk(k=10, dim=1)
            output_dict["predictions"] = top_tokens

        return output_dict
tokens = sorted([t for t in tokens.tolist() if t != 0]) preds = sorted(preds.tolist()[: len(tokens)]) predictions_str += f""" COCO Instance IDs (GT) : {tokens} COCO Instance IDs (Pred) : {preds} """ return predictions_str ================================================ FILE: virtex/models/masked_lm.py ================================================ from typing import Any, Dict import torch from torch import nn from virtex.data.tokenizers import SentencePieceBPETokenizer from virtex.modules.textual_heads import TextualHead from virtex.modules.visual_backbones import VisualBackbone class MaskedLMModel(nn.Module): r""" A model to perform BERT-like masked language modeling. It is composed of a :class:`~virtex.modules.visual_backbones.VisualBackbone` and a :class:`~virtex.modules.textual_heads.TextualHead` on top of it. During training, the model received caption tokens with certain tokens replaced by ``[MASK]`` token, and it predicts these masked tokens based on surrounding context. Args: visual: A :class:`~virtex.modules.visual_backbones.VisualBackbone` which computes visual features from an input image. textual: A :class:`~virtex.modules.textual_heads.TextualHead` which makes final predictions conditioned on visual features. """ def __init__(self, visual: VisualBackbone, textual: TextualHead): super().__init__() self.visual = visual self.textual = textual self.padding_idx = self.textual.padding_idx self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx) def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, Any]: r""" Given a batch of images and captions with certain masked tokens, predict the tokens at masked positions. Args: batch: A batch of images, ground truth caption tokens and masked labels. Possible set of keys: ``{"image_id", "image", "caption_tokens", "masked_labels", "caption_lengths"}``. 
Returns: A dict with the following structure, containing loss for optimization, loss components to log directly to tensorboard, and optionally predictions. .. code-block:: { "loss": torch.Tensor, "loss_components": {"masked_lm": torch.Tensor}, "predictions": torch.Tensor } """ # shape: (batch_size, channels, height, width) visual_features = self.visual(batch["image"]) caption_tokens = batch["caption_tokens"] caption_lengths = batch["caption_lengths"] masked_labels = batch["masked_labels"] # shape: (batch_size, num_caption_tokens, vocab_size) output_logits = self.textual(visual_features, caption_tokens, caption_lengths) output_dict: Dict[str, Any] = { "loss": self.loss( output_logits.view(-1, output_logits.size(-1)), masked_labels.view(-1) ) } # Single scalar per batch for logging in training script. output_dict["loss_components"] = { "masked_lm": output_dict["loss"].clone().detach() } # During evaluation, get predictions from logits. Useful for logging. # Only the predictions at [MASK]ed positions are relevant. 
class WordAndPositionalEmbedding(nn.Module):
    r"""
    A :class:`~torch.nn.Module` for learned word embeddings and position
    embeddings for input tokens. Each token is mapped to a fixed dimensional
    word embedding; and corresponding positional embedding based on its index.
    These are summed together followed by layer normalization and an optional
    dropout.

    Args:
        vocab_size: Size of token vocabulary.
        hidden_size: Size of token embedding vectors.
        dropout: Probability for final dropout applied after layer normalization.
        max_caption_length: Maximum length of input captions; this is used to
            create a fixed positional embedding lookup table.
        padding_idx: Token index of ``[PAD]`` token, word embedding for these
            tokens will be a vector of zeroes (and not trainable).
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        dropout: float = 0.0,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        self.words = nn.Embedding(vocab_size, hidden_size, padding_idx=padding_idx)

        # We provide no "padding index" for positional embeddings. We zero out
        # the positional embeddings of padded positions as a post-processing.
        self.positions = nn.Embedding(max_caption_length, hidden_size)
        self.layer_norm = nn.LayerNorm(
            hidden_size, eps=1e-8, elementwise_affine=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        r"""
        Get combined word and positional embeddings for input tokens.

        Args:
            tokens: A tensor of shape ``(batch_size, max_caption_length)``
                containing a batch of caption tokens, values in ``[0, vocab_size)``.

        Returns:
            A tensor of shape ``(batch_size, max_caption_length, hidden_size)``
            containing corresponding token embeddings. Rows at positions holding
            ``padding_idx`` are all zeroes.
        """
        position_indices = self._create_position_indices(tokens)

        # shape: (batch_size, max_caption_length, hidden_size)
        word_embeddings = self.words(tokens)
        position_embeddings = self.positions(position_indices)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = self.layer_norm(word_embeddings + position_embeddings)
        embeddings = self.dropout(embeddings)

        # Zero-out embeddings for positions which have padding tokens.
        # shape: (batch_size, max_caption_length, 1)
        token_mask = (tokens != self.padding_idx).unsqueeze(-1)

        # shape: (batch_size, max_caption_length, hidden_size)
        embeddings = embeddings * token_mask.type(embeddings.dtype)
        return embeddings

    def _create_position_indices(self, tokens: torch.Tensor) -> torch.Tensor:
        r"""
        Create position indices ``0 .. max_caption_length - 1`` for every row
        of ``tokens``.

        This is deliberately *not* memoized: a cache keyed on ``(self, tokens)``
        hashes tensors by identity, so it practically never hits, while it
        keeps this module and every cached batch of tokens alive.

        Args:
            tokens: A tensor of shape ``(batch_size, max_caption_length)``.

        Returns:
            A tensor of shape ``(batch_size, max_caption_length)`` with the same
            dtype and device as ``tokens``.
        """

        # Create position indices of the same size as token indices.
        batch_size, max_caption_length = tokens.size()
        positions = torch.arange(
            max_caption_length, dtype=tokens.dtype, device=tokens.device
        )
        # ``expand`` returns a broadcast view, no per-row copy is made.
        # shape: (batch_size, max_caption_length)
        positions = positions.unsqueeze(0).expand(batch_size, max_caption_length)
        return positions
class TransformerDecoderTextualHead(TextualHead):
    r"""
    A textual head composed of four main modules: (1) input projection (linear
    layer) for visual features to match size with textual features, (2) word
    and positional embedding for input captions, (3) a unidirectional
    transformer decoder, and (4) an output projection (linear layer) to predict
    a distribution over vocabulary tokens. The word embedding weights are tied
    with output projection; the latter still has its own learnable bias.

    .. note::

        For the "bicaptioning" pretraining task, our *textual head* (as defined
        in the paper) must have two transformer decoders: one each to decode
        caption in either direction. This class however will always have one
        transformer per object.

        Refer :class:`~virtex.models.captioning.BidirectionalCaptioningModel`
        source to understand how an object of this class is cloned, along with
        tying embedding and output weights, for bicaptioning.

        Hence, while there are *two objects* of this class, it is pragmatically
        a *single* textual head as a whole, according to the terminology used
        in paper.

    Args:
        visual_feature_size: Size (number of channels) of the input features
            from the visual backbone.
        vocab_size: Number of tokens in the output vocabulary.
        hidden_size: Size of the token embedding vectors, or hidden state vector
            of the language model.
        num_layers: Number of layers in the transformer.
        attention_heads: Number of attention heads in the transformer.
        feedforward_size: Size of feedforward layers in the transformer.
        dropout: Dropout probability for transformer (applied after layernorm).
        norm_first: Whether to apply normalization before or after attention/FF
            layers. The former type are called pre-norm variants (like GPT-2)
            and latter are post-norm variants (like BERT). Default is post-norm.
        mask_future_positions: Whether to mask future positions for
            self-attention over caption tokens. This must be ``True`` for
            captioning (and bicaptioning) tasks to prevent the language model
            from cheating, and ``False`` for masked language modeling, as the
            self-attention should consider all tokens.
        max_caption_length: Maximum length of input captions; this is used to
            create a fixed positional embedding lookup table.
        padding_idx: Token index of ``[PAD]`` token, word embedding for these
            tokens will be a vector of zeroes (and not trainable).
    """

    def __init__(
        self,
        visual_feature_size: int,
        vocab_size: int,
        hidden_size: int,
        num_layers: int,
        attention_heads: int,
        feedforward_size: int,
        dropout: float = 0.1,
        norm_first: bool = False,
        mask_future_positions: bool = True,
        max_caption_length: int = 30,
        padding_idx: int = 0,
    ):
        super().__init__(visual_feature_size, vocab_size, hidden_size)

        # Plain hyper-parameters, kept around for introspection.
        self.num_layers = num_layers
        self.attention_heads = attention_heads
        self.feedforward_size = feedforward_size
        self.dropout = dropout
        self.mask_future_positions = mask_future_positions
        self.padding_idx = padding_idx

        # NOTE: sub-modules below are created in a fixed order on purpose, the
        # order decides how random numbers are consumed during initialization.
        self.visual_projection = nn.Linear(
            visual_feature_size, self.textual_feature_size
        )
        self.embedding = WordAndPositionalEmbedding(
            self.vocab_size,
            self.textual_feature_size,
            dropout=dropout,
            max_caption_length=max_caption_length,
            padding_idx=padding_idx,
        )

        # A stack of ``nn.TransformerDecoderLayer`` wrapped in
        # ``nn.TransformerDecoder``.
        decoder_layer = nn.TransformerDecoderLayer(
            self.textual_feature_size,
            self.attention_heads,
            dim_feedforward=self.feedforward_size,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=norm_first,
        )
        # Pre-norm transformers need one more layer norm after the last layer.
        final_norm = nn.LayerNorm(self.hidden_size) if norm_first else None
        self.transformer = nn.TransformerDecoder(
            decoder_layer, num_layers=self.num_layers, norm=final_norm
        )
        self.apply(self._init_weights)

        # The output layer is created after ``apply`` above; its weight is then
        # tied to the input word embeddings to reduce parameters.
        self.output = nn.Linear(self.textual_feature_size, vocab_size)
        self.output.weight = self.embedding.words.weight

    @staticmethod
    def _init_weights(module):
        r"""
        Initialize weights like BERT - N(0.0, 0.02). Only weight tensors are
        touched here, biases are left as created by PyTorch.
        """

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            return

        if isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            # Keep the embedding of the padding token at zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        visual_features: torch.Tensor,
        caption_tokens: torch.Tensor,
        caption_lengths: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Given visual features from visual backbone and caption tokens, predict
        the output logits for next time-step.

        Args:
            visual_features: A tensor of shape ``(batch_size, channels, height,
                width)`` containing features from visual backbone.
            caption_tokens: A tensor of shape ``(batch_size, max_caption_length)``
                of caption tokens padded to the right by ``padding_idx``.
            caption_lengths: A tensor of shape ``(batch_size, )`` containing
                lengths of caption tokens in the batch.

        Returns:
            A tensor of shape ``(batch_size, max_caption_length, vocab_size)``
            containing output vocabulary logits for each time-step.
        """

        # Flatten the spatial grid and move channels last, then project visual
        # features to textual feature size.
        batch_size, channels, height, width = visual_features.size()
        flat_features = visual_features.view(batch_size, channels, -1).permute(0, 2, 1)

        # shape: (batch_size, height * width, textual_feature_size)
        projected_visual_features = self.visual_projection(flat_features)

        # Note that `max_caption_length` here may be less than the
        # `max_caption_length` passed in `__init__`, but it does not matter.
        batch_size, max_caption_length = caption_tokens.size()

        # Binary mask which is True at padding positions, i.e. wherever the
        # 1-based position exceeds the caption length. These positions are
        # ignored by multi-headed attention.
        # shape: (batch_size, max_caption_length)
        one_based_positions = torch.arange(
            1,
            max_caption_length + 1,
            dtype=caption_tokens.dtype,
            device=caption_tokens.device,
        ).unsqueeze(0)
        caption_mask = caption_lengths.unsqueeze(1) < one_based_positions

        # shape: (batch_size, max_caption_length, textual_feature_size)
        caption_embeddings = self.embedding(caption_tokens)

        # An additive mask for masking the future (one direction), only needed
        # for auto-regressive modeling.
        future_mask = None
        if self.mask_future_positions:
            future_mask = self.make_future_mask(
                max_caption_length, caption_embeddings.dtype, caption_embeddings.device
            )

        # shape: (batch_size, max_caption_length, hidden_size)
        textual_features = self.transformer(
            caption_embeddings,
            projected_visual_features,
            tgt_mask=future_mask,
            tgt_key_padding_mask=caption_mask,
        )
        # shape: (batch_size, max_caption_length, vocab_size)
        return self.output(textual_features)

    @staticmethod
    @functools.cache
    def make_future_mask(
        size: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """
        Generate a mask for "future" positions. Masked positions will be negative
        infinity. This mask is critical for causal language modeling.

        Results are memoized per ``(size, dtype, device)``, so the returned
        tensor is shared between calls.
        """
        blocked = torch.full((size, size), float("-inf"), dtype=dtype, device=device)
        # Keep -inf strictly above the main diagonal, zero elsewhere.
        return torch.triu(blocked, diagonal=1)
class TorchvisionVisualBackbone(VisualBackbone):
    r"""
    A visual backbone from `Torchvision model zoo
    <https://pytorch.org/vision/stable/models.html>`_. Any model can be
    specified using corresponding method name from the model zoo.

    .. note::

        The constructor always passes ``zero_init_residual=True`` to the model
        builder and :meth:`forward` stops at the child module named
        ``layer4``, so the chosen model must expose a ResNet-like layout.

    Args:
        name: Name of the model from Torchvision model zoo.
        visual_feature_size: Size of the channel dimension of output visual
            features from forward pass.
        pretrained: Whether to load ImageNet pretrained weights from Torchvision.
        frozen: Whether to keep all weights frozen during training.
    """

    def __init__(
        self,
        name: str = "resnet50",
        visual_feature_size: int = 2048,
        pretrained: bool = False,
        frozen: bool = False,
    ):
        super().__init__(visual_feature_size)

        self.cnn = getattr(torchvision.models, name)(
            pretrained, zero_init_residual=True
        )
        # Do nothing after the final residual stage.
        self.cnn.fc = nn.Identity()

        # Freeze all weights if specified.
        if frozen:
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.cnn.eval()

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        r"""
        Compute visual features for a batch of input images.

        Args:
            image: Batch of input images. A tensor of shape
                ``(batch_size, 3, height, width)``.

        Returns:
            A tensor of shape ``(batch_size, channels, height, width)``, for
            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.

        Raises:
            RuntimeError: If the wrapped model has no child module named
                ``layer4`` (previously this case silently returned ``None``).
        """
        out = image
        for name, layer in self.cnn.named_children():
            out = layer(out)

            # These are the spatial features we need.
            if name == "layer4":
                # shape: (batch_size, channels, height, width)
                return out

        raise RuntimeError(
            f"{type(self.cnn).__name__} has no child module named 'layer4'; "
            "only ResNet-like Torchvision models are supported."
        )

    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
        r"""
        Return state dict of visual backbone which can be loaded with
        `Detectron2 <https://github.com/facebookresearch/detectron2>`_.
        This is useful for downstream tasks based on Detectron2 (such as
        object detection and instance segmentation). This method renames
        certain parameters from Torchvision-style to Detectron2-style.

        Returns:
            A dict with three keys: ``{"model", "__author__",
            "matching_heuristics"}``. These are necessary keys for loading
            this state dict properly with Detectron2.
        """
        # Detectron2 backbones have slightly different module names, this mapping
        # lists substrings of module names required to be renamed for loading a
        # torchvision model into Detectron2.
        DETECTRON2_RENAME_MAPPING: Dict[str, str] = {
            "layer1": "res2",
            "layer2": "res3",
            "layer3": "res4",
            "layer4": "res5",
            "bn1": "conv1.norm",
            "bn2": "conv2.norm",
            "bn3": "conv3.norm",
            "downsample.0": "shortcut",
            "downsample.1": "shortcut.norm",
        }
        # Populate this dict by renaming module names. Replacements are applied
        # one after another, in the insertion order of the mapping above.
        d2_backbone_dict: Dict[str, torch.Tensor] = {}
        for name, param in self.cnn.state_dict().items():
            for old, new in DETECTRON2_RENAME_MAPPING.items():
                name = name.replace(old, new)

            # First conv and bn module parameters are prefixed with "stem.".
            if not name.startswith("res"):
                name = f"stem.{name}"

            d2_backbone_dict[name] = param

        return {
            "model": d2_backbone_dict,
            "__author__": "Karan Desai",
            "matching_heuristics": True,
        }
class Lookahead(Optimizer):
    r"""
    Implements Lookahead optimizer.

    The wrapped optimizer updates the "fast" weights on every step. Once every
    ``k`` steps the parameters are set to
    ``alpha * fast + (1 - alpha) * slow`` and that result becomes the new
    "slow" weights.

    Args:
        optimizer: Wrapper inner optimizer. The weights it manages will be the
            "fast" weights.
        k: Number of lookahead steps before updating "slow" weights.
        alpha: Linear interpolation factor, 1.0 recovers inner optimizer.
    """

    def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.8):
        # ``Optimizer.__init__`` is not invoked; ``param_groups`` is served by
        # the property below, straight from the wrapped optimizer.
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha

        # Counter for inner optimizer.
        self._k_counter = 0

        # Per-parameter storage, keyed by the parameter itself.
        self.state: Dict[Any, Any] = defaultdict(dict)
        self._snapshot_slow_params()

    def _snapshot_slow_params(self) -> None:
        r"""Store a copy of every current parameter as its "slow" weight."""
        for group in self.optimizer.param_groups:
            for param in group["params"]:
                self.state[param]["slow_params"] = param.data.clone()

    def __getstate__(self):
        return {
            "state": self.state,
            "optimizer": self.optimizer,
            "alpha": self.alpha,
            "k": self.k,
            "_k_counter": self._k_counter,
        }

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def zero_grad(self):
        r"""Clear all grad buffers at the start of new forward pass."""
        self.optimizer.zero_grad()

    def state_dict(self):
        r"""Return the state dict of the wrapped (inner) optimizer only."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]):
        r"""
        Load ``state_dict`` into the wrapped optimizer, then re-create the slow
        weights from the current parameter values.
        """
        self.optimizer.load_state_dict(state_dict)
        self._snapshot_slow_params()

    def step(self, closure: Callable = None):
        r"""
        Perform a single Lookahead optimization step.

        Args:
            closure: A callable that re-evaluates the model and returns loss.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        loss = self.optimizer.step(closure)
        self._k_counter += 1

        # Not yet time for the slow-weight update.
        if self._k_counter < self.k:
            return loss

        self._k_counter = 0
        for group in self.optimizer.param_groups:
            for param in group["params"]:
                slow = self.state[param]["slow_params"]
                # fast <- alpha * fast + (1 - alpha) * slow
                param.data.mul_(self.alpha).add_(slow, alpha=1.0 - self.alpha)
                # The interpolated value is the new slow weight.
                slow.copy_(param.data)
        return loss

    def load_slow_weights(self):
        r"""
        Load slow weights from Lookahead optimizer. Useful for performing
        evaluation on the slow weights (which typically generalize better).

        This method backs up fast weights to load them after evaluation. No
        need to call this method if evaluation happens just after a lookahead
        step.
        """
        for group in self.optimizer.param_groups:
            for param in group["params"]:
                param_state = self.state[param]
                param_state["backup_params"] = param.data.clone()
                param.data.copy_(param_state["slow_params"])

    def restore_fast_weights(self):
        r"""
        Restore fast weights for optimization. Call this after evaluation if
        :meth:`load_slow_weights` was called.
        """
        for group in self.optimizer.param_groups:
            for param in group["params"]:
                param_state = self.state[param]
                param.data.copy_(param_state["backup_params"])
                param_state.pop("backup_params")
class LinearWarmupMultiStepLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases learning rate from 0
    LR, and further decreases it by gamma once the number of steps reaches one
    of the milestones.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Total epochs (or iterations) for training.
        warmup_steps: Number of first few steps to do linear warmup.
        milestones: List of step indices (epochs or iterations depending on
            context). Must be increasing. May be empty, in which case the
            learning rate stays constant after warmup.
        gamma: Multiplicative factor of learning rate decay.
        last_epoch: The index of last step (epoch or iteration). We named it
            ``last_epoch`` instead of ``last_step`` to keep the naming consistent
            with other LR schedulers in PyTorch.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int,
        milestones: List[int],
        gamma: float = 0.1,
        last_epoch: int = -1,
    ):
        self.tsteps = total_steps
        self.wsteps = warmup_steps
        self.milestones = milestones
        self.gamma = gamma

        # Keep a track of number of milestones encountered.
        self.milestones_so_far = 0

        # Common sanity checks. ``list(...)`` lets tuples pass the order check.
        assert (
            warmup_steps < total_steps
        ), "Warmup steps should be less than total steps."
        assert list(milestones) == sorted(milestones), "milestones must be increasing"

        # Boundary checks only make sense when there is at least one milestone;
        # indexing an empty list here used to raise ``IndexError``.
        if len(milestones) > 0:
            assert (
                milestones[0] > warmup_steps
            ), "first milestone must be after warmup"
            assert (
                milestones[-1] < total_steps
            ), "last milestone must be less than total steps"

        super().__init__(optimizer, self._lr_multiplier, last_epoch)

    def _lr_multiplier(self, step: int) -> float:
        r"""Return the factor by which the base LR is scaled at ``step``."""
        if step < self.wsteps:
            # Linear warmup.
            multiplier = step / float(max(1, self.wsteps))
        else:
            # Step decay based on milestones: one factor of ``gamma`` for every
            # milestone that is less than or equal to ``step``.
            multiplier = self.gamma ** bisect.bisect_right(self.milestones, step)

        # Avoid negative learning rate.
        return max(0, multiplier)
math:: \eta_t = \eta_{max}\cos^2(\frac{T_{cur} - T_{warm}}{T_{max} - T_{warm}}\frac{\pi}{2}) Args: optimizer: Wrapped optimizer. total_steps: Total epochs (or iterations) for training. warmup_steps: Number of first few steps to do linear warmup. last_epoch: The index of last step (epoch or iteration). We named it ``last_epoch`` instead of ``last_step`` to keep the naming consistent with other LR schedulers in PyTorch. """ def __init__( self, optimizer: Optimizer, total_steps: int, warmup_steps: int, last_epoch: int = -1, ): assert ( warmup_steps < total_steps ), "Warmup steps should be less than total steps." self.tsteps = total_steps self.wsteps = warmup_steps super().__init__(optimizer, self._lr_multiplier, last_epoch) def _lr_multiplier(self, step: int) -> float: if step < self.wsteps: # Linear warmup. multiplier = step / float(max(1, self.wsteps)) else: # Cosine annealing decay. cos_factor = (step - self.wsteps) / (self.tsteps - self.wsteps) multiplier = math.cos(cos_factor * (math.pi / 2)) ** 2 # Avoid negative learning rate. return max(0, multiplier) ================================================ FILE: virtex/utils/beam_search.py ================================================ r""" This Beam Search implementation is adapted with minor modifications from `AllenNLP `_. Thanks to the developers of AllenNLP! **Update (v1.2):** The "backpointer" trick in Beam Search (as implemented in AllenNLP) does not work well with autoregressive models (transformers). It is now removed and it improves qualitative predictions and captioning metrics (CIDEr/SPICE) for VirTex. Updated captioning results are on ArXiv v3. Refer `CHANGELOG `_ and `Release Page `_ for more details. Huge thanks to Nicolas Carion (@alcinos) and Aishwarya Kamath (@ashkamath) for helping me fix this bug! 
class AutoRegressiveBeamSearch:
    r"""
    Implements the beam search algorithm for decoding the most likely captions.

    Args:
        eos_index: The index of the end token (``[EOS]``) in vocabulary.
        max_steps: The maximum number of decoding steps.
        beam_size: The width of the beam used.
        per_node_beam_size: The maximum number of candidates to consider per node,
            at each step in the search. Setting this parameter to a number smaller
            than ``beam_size`` may give better results, as it can introduce more
            diversity into the search. A falsy value (``0`` or ``None``) falls
            back to ``beam_size``. See `Beam Search Strategies for Neural Machine
            Translation. Freitag and Al-Onaizan, 2017
            <https://arxiv.org/abs/1702.01806>`_.
    """

    def __init__(
        self,
        eos_index: int,
        max_steps: int = 50,
        beam_size: int = 5,
        per_node_beam_size: int = 2,
    ) -> None:
        self._eos_index = eos_index
        self.max_steps = max_steps
        self.beam_size = beam_size
        self.per_node_beam_size = per_node_beam_size or beam_size

    @staticmethod
    def _select_output(
        predictions: torch.Tensor, logprobs: torch.Tensor, only_return_best: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Keep only the first (highest scoring) beam if requested."""
        if only_return_best:
            # shape: (batch_size, sequence_length), (batch_size, )
            return predictions[:, 0, :], logprobs[:, 0]
        return predictions, logprobs

    def search(
        self,
        start_predictions: torch.Tensor,
        step: Callable[..., torch.Tensor],
        only_return_best: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Given a starting state and a step function, apply beam search to find
        the most likely target captions.

        Args:
            start_predictions: Tensor containing the initial predictions, shape
                ``(batch_size, )``. Usually the initial predictions are just the
                index of the start token (``[SOS]``) in the vocabulary.
            step: A function that is responsible for computing the next most likely
                tokens, given the past predictions. Predictions from all previous
                timesteps are required, not just the last timestep. The function is
                expected to return a tensor of shape ``(group_size, target_vocab_size)``
                containing the token logits for the next step.
            only_return_best: Whether to only return the best beam (with highest
                logprobs). Set this to ``False`` to return all the beams. If this is
                ``True``, then the returned tensor is of shape
                ``(batch_size, sequence_length)``, else will be
                ``(batch_size, beam_size, sequence_length)``.

        Returns:
            Tuple of ``(predictions, logprobs)``. With ``only_return_best=False``,
            ``predictions`` has shape ``(batch_size, beam_size, sequence_length)``
            and ``logprobs`` has shape ``(batch_size, beam_size)``. With
            ``only_return_best=True`` the beam dimension is dropped from both.
            ``sequence_length`` is at most ``max_steps``.
        """

        batch_size = start_predictions.size()[0]

        # Decoded tokens so far, shape `(batch_size, beam_size, length)`.
        # Does not include the start symbols, which are implicit.
        predictions: torch.Tensor = torch.empty(
            (batch_size, self.beam_size, 0),
            dtype=torch.long,
            device=start_predictions.device,
        )
        # Calculate the first timestep. This is done outside the main loop
        # because we are going from a single decoder input (the output from the
        # encoder) to the top `beam_size` decoder outputs. On the other hand,
        # within the main loop we are going from the `beam_size` elements of the
        # beam to `beam_size`^2 candidates from which we will select the top
        # `beam_size` elements for the next iteration.
        # shape: (batch_size, num_classes)
        start_class_logits = step(start_predictions)

        # Convert logits to logprobs.
        # shape: (batch_size, num_classes)
        start_class_logprobs = F.log_softmax(start_class_logits, dim=1)

        num_classes = start_class_logprobs.size()[1]

        # shape: (batch_size, beam_size), (batch_size, beam_size)
        start_top_logprobs, start_predicted_classes = start_class_logprobs.topk(
            self.beam_size
        )

        if self.beam_size == 1 and (start_predicted_classes == self._eos_index).all():
            warnings.warn(
                "Empty captions predicted. You may want to increase beam "
                "size or ensure your step function is working properly.",
                RuntimeWarning,
            )
            # Honor `only_return_best` here as well, so that the output shapes
            # are the same as for a regular (non-empty) decode.
            return self._select_output(
                start_predicted_classes.unsqueeze(-1),
                start_top_logprobs,
                only_return_best,
            )

        # The log probs for the last time step.
        # shape: (batch_size, beam_size)
        last_logprobs = start_top_logprobs

        # shape: (batch_size, beam_size, sequence_length)
        predictions = torch.cat(
            [predictions, start_predicted_classes.unsqueeze(-1)], dim=-1
        )

        # Log probability tensor that mandates that the end token is selected.
        # shape: (batch_size * beam_size, num_classes)
        logprobs_after_end = start_class_logprobs.new_full(
            (batch_size * self.beam_size, num_classes), float("-inf")
        )
        logprobs_after_end[:, self._eos_index] = 0.0

        for timestep in range(self.max_steps - 1):
            # shape: (batch_size * beam_size,)
            last_predictions = predictions[:, :, -1].reshape(
                batch_size * self.beam_size
            )

            # If every predicted token from the last step is `self._eos_index`,
            # then we can stop early.
            if (last_predictions == self._eos_index).all():
                break

            predictions_so_far = predictions.view(batch_size * self.beam_size, -1)
            # shape: (batch_size * beam_size, num_classes)
            class_logits = step(predictions_so_far)

            # Convert logits to logprobs.
            # shape: (batch_size * beam_size, num_classes)
            class_logprobs = F.log_softmax(class_logits, dim=1)

            # Set logprobs of last predicted tokens as high negative value to avoid
            # repetition in caption. One scatter handles every row at once.
            class_logprobs.scatter_(1, last_predictions.unsqueeze(-1), -10000.0)

            # shape: (batch_size * beam_size, num_classes)
            last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
                batch_size * self.beam_size, num_classes
            )
            # Here we are finding any beams where we predicted the end token in
            # the previous timestep and replacing the distribution with a
            # one-hot distribution, forcing the beam to predict the end token
            # this timestep as well.
            # shape: (batch_size * beam_size, num_classes)
            cleaned_logprobs = torch.where(
                last_predictions_expanded == self._eos_index,
                logprobs_after_end,
                class_logprobs,
            )
            # shape (both): (batch_size * beam_size, per_node_beam_size)
            top_logprobs, predicted_classes = cleaned_logprobs.topk(
                self.per_node_beam_size
            )
            # Here we expand the last log probs to `(batch_size * beam_size,
            # per_node_beam_size)` so that we can add them to the current log
            # probs for this timestep. This lets us maintain the log
            # probability of each element on the beam.
            # shape: (batch_size * beam_size, per_node_beam_size)
            expanded_last_logprobs = (
                last_logprobs.unsqueeze(2)
                .expand(batch_size, self.beam_size, self.per_node_beam_size)
                .reshape(batch_size * self.beam_size, self.per_node_beam_size)
            )
            # shape: (batch_size * beam_size, per_node_beam_size)
            summed_top_logprobs = top_logprobs + expanded_last_logprobs

            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_summed = summed_top_logprobs.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )
            # shape: (batch_size, beam_size * per_node_beam_size)
            reshaped_predicted_classes = predicted_classes.reshape(
                batch_size, self.beam_size * self.per_node_beam_size
            )
            # Append the predictions to the current beam.
            reshaped_beam = (
                predictions.view(batch_size * self.beam_size, 1, -1)
                .repeat(1, self.per_node_beam_size, 1)
                .reshape(batch_size, self.beam_size * self.per_node_beam_size, -1)
            )
            reshaped_beam = torch.cat(
                [reshaped_beam, reshaped_predicted_classes.unsqueeze(-1)], dim=-1
            )

            # Keep only the top `beam_size` beam indices.
            # shape: (batch_size, beam_size), (batch_size, beam_size)
            restricted_beam_logprobs, restricted_beam_indices = reshaped_summed.topk(
                self.beam_size
            )
            predictions = reshaped_beam.gather(
                1,
                restricted_beam_indices.unsqueeze(-1).repeat(
                    1, 1, reshaped_beam.shape[-1]
                ),
            )

            # shape: (batch_size, beam_size)
            last_logprobs = restricted_beam_logprobs

        if not torch.isfinite(last_logprobs).all():
            warnings.warn(
                "Infinite log probs encountered. Some final captions may not "
                "make sense. This can happen when the beam size is larger than"
                " the number of valid (non-zero probability) transitions that "
                "the step function produces.",
                RuntimeWarning,
            )

        # Optionally select best beam and its logprobs.
        return self._select_output(predictions, last_logprobs, only_return_best)
checkpointables: Keyword arguments with any checkpointable objects, for example: model, optimizer, learning rate scheduler. Examples: >>> model = torch.nn.Linear(10, 2) >>> optimizer = torch.optim.Adam(model.parameters()) >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer) >>> num_epochs = 20 >>> for epoch in range(num_epochs): ... train(model) ... val_loss = validate(model) ... ckpt_manager.step(- val_loss, epoch) """ def __init__( self, serialization_dir: str = "/tmp", keep_recent: int = 200, **checkpointables: Any, ): self.serialization_dir = pathlib.Path(serialization_dir) self.keep_recent = keep_recent # Shallow copy, keeps references to tensors as original objects. self.checkpointables = copy.copy(checkpointables) # Initialize members to hold state dict of best checkpoint and its # performance. self._best_metric: float = -1e-12 self._best_ckpt: Dict[str, Any] = {} # Keep epoch/iteration numbers of recently saved 'k' checkpoints. self._recent_iterations: List[int] = [] def step(self, iteration: int, metric: Optional[float] = None): r""" Serialize checkpoint and update best checkpoint based on metric. Keys in serialized checkpoint match those in :attr:`checkpointables`. Args: iteration: Current training iteration. Will be saved with other checkpointables. metric: Observed metric (higher is better) for keeping track of the best checkpoint. If this is ``None``, best chckpoint will not be recorded/updated. """ checkpointable_state_dict: Dict[str, Any] = self._state_dict() # We also checkpoint current iteration. checkpointable_state_dict["iteration"] = iteration # Update the best checkpoint based on metric, if provided. if metric is not None and metric > self._best_metric: self._best_metric = metric self._best_ckpt = copy.copy(checkpointable_state_dict) # Serialize checkpoint corresponding to current iteration. 
torch.save( checkpointable_state_dict, self.serialization_dir / f"checkpoint_{iteration}.pth", ) if self._best_metric != -1e-12: # Serialize best performing checkpoint observed so far. torch.save( self._best_ckpt, self.serialization_dir / "checkpoint_best.pth" ) # Remove earliest checkpoint if there are more on disk. self._recent_iterations.append(iteration) if len(self._recent_iterations) > self.keep_recent: self.remove_earliest_checkpoint() def _state_dict(self): r"""Return a dict containing state dict of all checkpointables.""" __state_dict: Dict[str, Any] = {} for key in self.checkpointables: if isinstance( self.checkpointables[key], nn.parallel.DistributedDataParallel ): __state_dict[key] = self.checkpointables[key].module.state_dict() else: __state_dict[key] = self.checkpointables[key].state_dict() return __state_dict def remove_earliest_checkpoint(self): r"""Remove earliest serialized checkpoint from disk.""" earliest_iteration = self._recent_iterations.pop(0) (self.serialization_dir / f"checkpoint_{earliest_iteration}.pth").unlink() def load(self, checkpoint_path: str): r""" Load a serialized checkpoint from a path. This method will try to find each of :attr:`checkpointables` in the file and load its state dict. Since our checkpointables are held as references, this method does not return them. Args: checkpoint_path: Path to a checkpoint serialized by :meth:`step`. Returns: Iteration corresponding to the loaded checkpoint. Useful for resuming training. This will be -1 in case of best checkpoint, or if info does not exist. """ # Each process will log a message after loading checkpoint. rank = dist.get_rank() logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location="cpu") iteration = checkpoint.pop("iteration", -1) # Keep flags of all checkpointables to lo which ones were not loaded. is_loaded = {key: False for key in self.checkpointables} # Load each checkpointable from checkpoint. 
def cycle(dataloader, device, start_iteration: int = 0):
    r"""
    Endlessly yield batches from ``dataloader``, restarting it whenever it is
    exhausted. Every value of a batch is moved to ``device`` (in place, the
    same batch object is yielded).

    If the dataloader's sampler is a ``DistributedSampler``, its epoch is set
    to the running iteration count before each pass; it only serves as a
    shuffling seed. ``start_iteration`` is the initial value of that count, so
    that shuffling continues naturally when training is resumed.
    """
    seen = start_iteration

    while True:
        sampler = dataloader.sampler
        if isinstance(sampler, torch.utils.data.DistributedSampler):
            # Set the `epoch` of DistributedSampler as current iteration. This
            # is just a deterministic seed, need not be the actual "epoch".
            logger.info(f"Beginning new epoch, setting shuffle seed {seen}")
            sampler.set_epoch(seen)

        for batch in dataloader:
            moved = {key: batch[key].to(device) for key in batch}
            for key, value in moved.items():
                batch[key] = value
            yield batch
            seen += 1
Basic steps: 1. Fix random seeds and other PyTorch flags. 2. Set up a serialization directory and loggers. 3. Log important stuff such as config, process info (useful during distributed training). 4. Save a copy of config to serialization directory. .. note:: It is assumed that multiple processes for distributed training have already been launched from outside. Functions from :mod:`virtex.utils.distributed` module ae used to get process info. Args: _C: Config object with all the parameters. _A: Argparse command line arguments. job_type: Type of job for which setup is to be done; one of ``{"pretrain", "downstream"}``. """ # Get process rank and world size (assuming distributed is initialized). RANK = dist.get_rank() WORLD_SIZE = dist.get_world_size() # For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html torch.manual_seed(_C.RANDOM_SEED) torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK random.seed(_C.RANDOM_SEED) np.random.seed(_C.RANDOM_SEED) # Create serialization directory and save config in it. os.makedirs(_A.serialization_dir, exist_ok=True) _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml")) # Remove default logger, create a logger for each process which writes to a # separate log-file. This makes changes in global scope. logger.remove(0) if dist.get_world_size() > 1: logger.add( os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"), format="{time} {level} {message}", ) # Add a logger for stdout only for the master process. if dist.is_master_process(): logger.add( sys.stdout, format="{time}: {message}", colorize=True ) # Print process info, config and args. logger.info(f"Rank of current process: {RANK}. 
def common_parser(description: str = "") -> argparse.ArgumentParser:
    r"""
    Create an argument parser with some common arguments useful for any
    pretraining or downstream evaluation scripts.

    Args:
        description: Description to be used with the argument parser.

    Returns:
        A parser object with added arguments.
    """
    parser = argparse.ArgumentParser(description=description)

    # fmt: off
    parser.add_argument(
        "--config", metavar="FILE", help="Path to a pretraining config file."
    )
    parser.add_argument(
        "--config-override", nargs="*", default=[],
        help="A list of key-value pairs to modify pretraining config params.",
    )
    parser.add_argument(
        "--serialization-dir", default="/tmp/virtex",
        help="Path to a directory to serialize checkpoints and save job logs."
    )

    group = parser.add_argument_group("Compute resource management arguments.")
    group.add_argument(
        "--cpu-workers", type=int, default=0,
        help="Number of CPU workers per GPU to use for data loading.",
    )
    group.add_argument(
        "--num-machines", type=int, default=1,
        help="Number of machines used in distributed training."
    )
    group.add_argument(
        "--num-gpus-per-machine", type=int, default=0,
        help="""Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as
        zero for single-process CPU training.""",
    )
    group.add_argument(
        "--machine-rank", type=int, default=0,
        help="""Rank of the machine, integer in [0, num_machines). Default 0
        for training with a single machine.""",
    )
    # Plain string literal: there is nothing to interpolate in this default.
    group.add_argument(
        "--dist-url", default="tcp://127.0.0.1:23456",
        help="""URL of the master process in distributed training, it defaults
        to localhost for single-machine training.""",
    )
    # fmt: on
    return parser
def launch(
    job_fn: Callable,
    num_machines: int = 1,
    num_gpus_per_machine: int = 1,
    machine_rank: int = 0,
    dist_url: str = "tcp://127.0.0.1:23456",
    args=(),
):
    r"""
    Run ``job_fn`` once per GPU across ``num_machines`` machines, each having
    ``num_gpus_per_machine`` GPUs. Processes are started through
    :func:`torch.multiprocessing.spawn`.

    This function must be invoked once on every machine, each time with a
    distinct ``machine_rank`` (0, 1, 2 ...); global process ranks are derived
    from the machine rank. One process on the machine with rank 0 is referred
    to as the *master process*; ``dist_url`` should point to that machine.
    With default arguments (one machine, one GPU) no processes are spawned and
    the worker runs directly in the calling process.

    .. note::

        The same number of GPUs is assumed on every machine, with IDs
        ``(0, 1, 2 ...)``. To use a subset of GPUs, set the
        ``CUDA_VISIBLE_DEVICES`` environment variable (for example
        ``CUDA_VISIBLE_DEVICES=5,6`` exposes GPUs 5 and 6 as IDs 0 and 1).

    Args:
        job_fn: A callable object to launch. Pass your main function doing
            training, validation etc. here.
        num_machines: Number of machines, each with ``num_gpus_per_machine``
            GPUs.
        num_gpus_per_machine: Number of GPUs per machine, with IDs as
            ``(0, 1, 2 ...)``.
        machine_rank: A manually specified rank of the machine, used to
            assign global ranks to processes.
        dist_url: Distributed process communication URL as
            ``tcp://x.x.x.x:port``. Set this as the IP (and a free port) of
            the machine with rank 0.
        args: Arguments to be passed to ``job_fn``.
    """
    assert torch.cuda.is_available(), (
        "CUDA not available, Cannot launch distributed processes."
    )

    total_processes = num_machines * num_gpus_per_machine

    if total_processes <= 1:
        # Single machine, single GPU with ID 0: run the worker in-process.
        _job_worker(0, job_fn, 1, 1, 0, dist_url, args)
        return

    # ``mp.spawn`` prepends the local process rank (the GPU ID) to these
    # arguments when it calls ``_job_worker``.
    worker_args = (
        job_fn,
        total_processes,
        num_gpus_per_machine,
        machine_rank,
        dist_url,
        args,
    )
    mp.spawn(
        _job_worker,
        nprocs=num_gpus_per_machine,
        args=worker_args,
        daemon=False,
    )
class TopkAccuracy:
    r"""
    Running Top-K classification accuracy, accumulated batch by batch and
    retrievable at any point during evaluation. Predictions are scores over
    classes and targets are integer labels (long tensors). When used with
    :class:`~torch.nn.parallel.DistributedDataParallel`, aggregating results
    across GPU processes is the caller's responsibility.

    Args:
        k: ``k`` for computing Top-K accuracy.
    """

    def __init__(self, k: int = 1):
        self._k = k
        self.reset()

    def reset(self):
        r"""Clear the accumulated counts of seen and correct examples."""
        self.num_total = 0.0
        self.num_correct = 0.0

    def __call__(self, predictions: torch.Tensor, ground_truth: torch.Tensor):
        r"""
        Accumulate counts for one batch of predictions and ground-truth.

        Args:
            predictions: Model predictions - logits or probabilities. Tensor
                of shape ``(num_classes, )`` (not batched) or
                ``(B, num_classes)``.
            ground_truth: Ground-truth integer labels. A scalar tensor or a
                batch tensor of shape ``(B, )`` with values in
                ``[0, num_classes-1]``.

        Returns:
            Accuracy (in percentage) so far.
        """
        if self._k != 1:
            # Never ask for more candidates than there are classes.
            num_candidates = min(self._k, predictions.shape[-1])
            _, candidates = predictions.topk(num_candidates, -1)
        else:
            _, best_class = predictions.max(-1)
            candidates = best_class.unsqueeze(-1)

        # Broadcast each label against its candidate classes; a row holds a
        # single 1.0 when the label is among the candidates.
        labels = ground_truth.unsqueeze(-1)
        hits = candidates.eq(labels).float()

        self.num_total += ground_truth.numel()
        self.num_correct += hits.sum()
        return self.get_result()

    def get_result(self):
        r"""Return accuracy (in percentage) over everything seen so far."""
        # Tiny epsilon avoids division by zero before any batch is recorded.
        denominator = self.num_total + 1e-12
        return self.num_correct / denominator * 100
def tokenize(image_id_to_captions: Dict[int, List[str]]) -> Dict[int, List[str]]:
    r"""
    Given a mapping of image id to a list of corresponding captions, tokenize
    the captions with the Penn Treebank tokenizer (lower-cased, PTB-style
    punctuation removed) and return them in a **new** mapping; the input
    mapping is not modified.

    This method assumes the presence of the Stanford CoreNLP JAR file in the
    directory of this module, and a ``java`` executable on ``PATH``.

    Args:
        image_id_to_captions: Mapping from image id to its list of captions.

    Returns:
        Mapping from image id to the list of tokenized captions, in the same
        order as the input captions. It is empty when there are no captions.

    Raises:
        RuntimeError: If the tokenizer process exits with a non-zero status.
    """
    # Path to the Stanford CoreNLP JAR file (relative to this module's
    # directory, which is used as the working directory of the subprocess).
    CORENLP_JAR = (
        "assets/stanford-corenlp-full-2014-08-27/stanford-corenlp-3.4.1.jar"
    )

    # Punctuations to be removed from the sentences (PTB style).
    # fmt: off
    PUNCTS = frozenset([
        "''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", ".", "?",
        "!", ",", ":", "-", "--", "...", ";",
    ])
    # fmt: on

    # Flatten to one caption per line, remembering which image owns each line.
    image_ids: List[int] = []
    lines: List[str] = []
    for image_id, captions in image_id_to_captions.items():
        for caption in captions:
            image_ids.append(image_id)
            lines.append(caption.replace("\n", " "))

    image_id_to_tokenized_captions: Dict[int, List[str]] = defaultdict(list)
    if not lines:
        # Nothing to tokenize: avoid starting a JVM for an empty input.
        return image_id_to_tokenized_captions

    tmp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        tmp_file.write("\n".join(lines).encode())
        tmp_file.close()

        # fmt: off
        # Tokenize sentences. We use the JAR file for tokenization; it reads
        # the captions from the temporary file, so nothing is sent on stdin.
        command = [
            "java", "-cp", CORENLP_JAR, "edu.stanford.nlp.process.PTBTokenizer",
            "-preserveLines", "-lowerCase", tmp_file.name
        ]
        # fmt: on
        process = Popen(
            command, cwd=os.path.dirname(os.path.abspath(__file__)), stdout=PIPE
        )
        stdout, _ = process.communicate()
    finally:
        # Always clean up, even if the subprocess could not be started.
        tmp_file.close()
        os.remove(tmp_file.name)

    if process.returncode != 0:
        # Otherwise a failed run would silently yield missing captions.
        raise RuntimeError(
            f"PTBTokenizer exited with status {process.returncode}."
        )

    tokenized_captions = stdout.decode().split("\n")

    # Map tokenized captions back to their image IDs.
    for image_id, caption in zip(image_ids, tokenized_captions):
        image_id_to_tokenized_captions[image_id].append(
            " ".join([w for w in caption.rstrip().split(" ") if w not in PUNCTS])
        )

    return image_id_to_tokenized_captions
def spice(
    predictions: Dict[int, List[str]], ground_truth: Dict[int, List[str]]
) -> float:
    r"""
    Compute SPICE score given ground truth captions and predictions.

    The SPICE JAR bundled under ``assets/SPICE-1.0`` (next to this module) is
    executed through ``java``; its inputs and outputs are exchanged via JSON
    files in a temporary directory which is removed before returning.

    Args:
        predictions: Mapping from image id to a list of predicted captions;
            only the first caption of each list is scored. It must contain
            every image id present in ``ground_truth``.
        ground_truth: Mapping from image id to its reference captions.

    Returns:
        Mean of per-image SPICE F-scores.

    Raises:
        subprocess.CalledProcessError: If the SPICE process exits with a
            non-zero status.
    """
    # Prepare input for the SPICE scorer.
    input_data = [
        {
            "image_id": image_id,
            "test": predictions[image_id][0],
            "refs": ground_truth[image_id],
        }
        for image_id in ground_truth
    ]

    CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
    SPICE_JAR = f"{CURRENT_DIR}/assets/SPICE-1.0/spice-1.0.jar"
    CACHE_DIR = f"{CURRENT_DIR}/assets/cache"
    os.makedirs(CACHE_DIR, exist_ok=True)

    # The context manager deletes the directory (and both JSON files) even if
    # the SPICE process fails.
    with tempfile.TemporaryDirectory() as temp_dir:
        INPUT_PATH = os.path.join(temp_dir, "input_file.json")
        OUTPUT_PATH = os.path.join(temp_dir, "output_file.json")

        # Close the file before launching Java so its contents are flushed.
        with open(INPUT_PATH, "w") as input_file:
            json.dump(input_data, input_file)

        # fmt: off
        # Run the command to execute SPICE jar.
        spice_cmd = [
            "java", "-jar", "-Xmx8G", SPICE_JAR, INPUT_PATH,
            "-cache", CACHE_DIR, "-out", OUTPUT_PATH, "-subset", "-silent",
        ]
        # fmt: on
        check_call(spice_cmd, cwd=CURRENT_DIR)

        # Read results while the temporary directory still exists.
        with open(OUTPUT_PATH) as output_file:
            results = json.load(output_file)

    spice_scores = [
        np.array(item["scores"]["All"]["f"]).astype(float) for item in results
    ]
    return np.mean(spice_scores)
This class only works for auto-regressive models (Transformer-like), not recurrent models (LSTM-like). Args: eos_index: The index of the end token (``[EOS]``) in vocabulary. max_steps: The maximum number of decoding steps. nucleus_size: Size of top-K nucleus for sampling. """ def __init__( self, eos_index: int, max_steps: int = 50, nucleus_size: float = 0.9, ): super().__init__() self._eos_index = eos_index self.max_steps = max_steps self.nucleus_size = nucleus_size def search( self, start_predictions: torch.Tensor, step: Callable[..., torch.Tensor] ) -> Tuple[torch.Tensor, None]: batch_size = start_predictions.size()[0] # List of `(batch_size, )` tensors. One for each timestep. # This includes the start-of-sentence tokens, unlike the implementation # in `AutoregressiveBeamSearch`. We will remove them in the end. predictions: List[torch.Tensor] = [start_predictions] for timestep in range(self.max_steps): # Get the predictions from last timestep (most recent). # shape: (batch_size, ) last_predictions = predictions[-1] # If every predicted token from the last step is end-of-sentence token, # then we can stop early. if (last_predictions == self._eos_index).all(): break # Combine step predictions made so far into one tensor. This is our # "partial" caption input to the transformer. # shape: (batch_size, timestep + 1) predictions_so_far = torch.stack(predictions).permute(1, 0) # Take a step, get the distribution of logits from next timestep. # shape: (batch_size, num_classes) current_logits = step(predictions_so_far) # Sort logits in descending order to determine the nucleus. sorted_logits, sorted_idx = torch.sort(current_logits, descending=True) # Get cumulative softmax probabilites. For every instance in batch, a # variable amount of tokens (N) will consitute the nucleus. # shape: (batch_size, num_classes) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Determine indices of tokens at the tail of distribution. 
class Timer:
    r"""
    A simple timer to record time per iteration and ETA of training. ETA is
    estimated by a moving window average over the most recent iterations
    that were actually timed (at most ``window_size`` of them).

    Args:
        start_from: Iteration from which counting should be started/resumed.
        total_iterations: Total number of iterations. ETA will not be tracked
            (will remain "N/A") if this is not provided.
        window_size: Window size to calculate ETA based on past few
            iterations. Values below 1 are treated as 1.
    """

    def __init__(
        self,
        start_from: int = 1,
        total_iterations: Optional[int] = None,
        window_size: int = 20,
    ):
        # We decrement by 1 because `current_iter` is incremented during
        # an iteration (for example, will change from 0 -> 1 on iteration 1).
        self.current_iter = start_from - 1
        self.total_iters = total_iterations

        self._window_size = max(1, window_size)
        self._start_time = time.perf_counter()
        # Only real measurements are stored. Pre-filling the window with
        # zeros would drag the average (and hence the ETA) towards zero
        # until `window_size` iterations have been recorded.
        self._times = []

    def tic(self) -> None:
        r"""Start recording time: call at the beginning of iteration."""
        # Monotonic clock: unaffected by system clock adjustments.
        self._start_time = time.perf_counter()

    def toc(self) -> None:
        r"""Stop recording time: call at the end of iteration."""
        self._times.append(time.perf_counter() - self._start_time)
        # Keep only the most recent `window_size` measurements.
        self._times = self._times[-self._window_size:]
        self.current_iter += 1

    @property
    def stats(self) -> str:
        r"""Return a single string with current iteration, time and ETA."""
        last_time = self._times[-1] if self._times else 0.0
        return (
            f"Iter {self.current_iter} | Time: {last_time:.3f} sec | "
            f"ETA: {self.eta_hhmm}"
        )

    @property
    def eta_hhmm(self) -> str:
        r"""
        Return ETA in the form of ``hh mm`` string, or ``"N/A"`` when total
        iterations are unknown or no iteration has been timed yet.
        """
        if not self.total_iters or not self._times:
            return "N/A"

        avg_time = sum(self._times) / len(self._times)
        # Never report a negative ETA if we ran past `total_iterations`.
        remaining_iters = max(0, self.total_iters - self.current_iter)
        eta_sec = int(avg_time * remaining_iters)
        return f"{eta_sec // 3600}h {((eta_sec % 3600) // 60):02d}m"