master b282e8b49e4e cached
26 files
120.7 KB
33.3k tokens
135 symbols
1 requests
Download .txt
Repository: OpenGVLab/efficient-video-recognition
Branch: master
Commit: b282e8b49e4e
Files: 26
Total size: 120.7 KB

Directory structure:
gitextract_7dyo2vh1/

├── .gitignore
├── README.md
├── checkpoint.py
├── data/
│   └── k400_class_mappings.json
├── main.py
├── model.py
├── scripts/
│   ├── eval_k400_vitb16_16f_dec4x768.sh
│   ├── eval_k400_vitb16_32f_dec4x768.sh
│   ├── eval_k400_vitb16_8f_dec4x768.sh
│   ├── eval_k400_vitl14_16f_dec4x1024.sh
│   ├── eval_k400_vitl14_32f_dec4x1024.sh
│   ├── eval_k400_vitl14_8f_dec4x1024.sh
│   ├── train_k400_vitb16_16f_dec4x768.sh
│   ├── train_k400_vitb16_32f_dec4x768.sh
│   ├── train_k400_vitb16_8f_dec4x768.sh
│   ├── train_k400_vitl14_16f_dec4x1024.sh
│   ├── train_k400_vitl14_32f_dec4x1024.sh
│   └── train_k400_vitl14_8f_dec4x1024.sh
├── video_dataset/
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── rand_augment.py
│   ├── random_erasing.py
│   └── transform.py
├── vision_transformer.py
└── weight_loaders.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
*.py[cod]
*$py.class

runs/

video_dataset/io_internal.py

================================================
FILE: README.md
================================================
# Frozen CLIP models are Efficient Video Learners

This is the official implementation of the paper [Frozen CLIP models are Efficient Video Learners](https://arxiv.org/abs/2208.03550)

```
@article{lin2022frozen,
  title={Frozen CLIP Models are Efficient Video Learners},
  author={Lin, Ziyi and Geng, Shijie and Zhang, Renrui and Gao, Peng and de Melo, Gerard and Wang, Xiaogang and Dai, Jifeng and Qiao, Yu and Li, Hongsheng},
  journal={arXiv preprint arXiv:2208.03550},
  year={2022}
}
```

## Introduction

The overall architecture of the EVL framework includes a trainable Transformer decoder, trainable local temporal modules and a pretrained, fixed image backbone
(CLIP is used for instance).

<img src="figs/arch.png" height="300">

Using a fixed backbone significantly saves training time, and we managed to train a ViT-B/16 with 8 frames for 50 epochs in 60 GPU-hours (NVIDIA V100).

Despite with a small training computation and memory consumption, EVL models achieves high performance on Kinetics-400. A comparison with state-of-the-art methods
are as follows

<img src="figs/k400.png" height="300">

## Installation

We tested the released code with the following conda environment

```
conda create -n pt1.9.0cu11.1_official -c pytorch -c conda-forge pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0 cudatoolkit torchvision av
```

## Data Preparation

We expect that `--train_list_path` and `--val_list_path` command line arguments to be a data list file of the following format
```
<path_1> <label_1>
<path_2> <label_2>
...
<path_n> <label_n>
```
where `<path_i>` points to a video file, and `<label_i>` is an integer between `0` and `num_classes - 1`.
`--num_classes` should also be specified in the command line argument.

Additionally, `<path_i>` might be a relative path when `--data_root` is specified, and the actual path will be
relative to the path passed as `--data_root`.

The class mappings in the open-source weights are provided at [Kinetics-400 class mappings](data/k400_class_mappings.json)

## Backbone Preparation

CLIP weights need to be downloaded from [CLIP official repo](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/clip.py#L30)
and passed to the `--backbone_path` command line argument.

## Script Usage

Training and evaliation scripts are provided in the scripts folder.
Scripts should be ready to run once the environment is setup and 
`--backbone_path`, `--train_list_path` and `--val_list_path` are replaced with your own paths.

For other command line arguments please see the help message for usage.

## Kinetics-400 Main Results

This is a re-implementation for open-source use.
We are still re-running some models, and their scripts, weights and logs will be released later.
In the following table we report the re-run accuracy, which may be slightly different from the original paper (typically +/-0.1%)

| Backbone | Decoder Layers | #frames x stride | top-1 | top-5 | Script | Model | Log |
| - | - | - | - | - | - | - | - |
| ViT-B/16 | 4 | 8 x 16 | 82.8 | 95.8 | [script](scripts/train_k400_vitb16_8f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1DoGjvDdkJoSa9i-wq1lh6QoEZIa4xTB3/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1-9vgsXMpnWBI9MxQV7SSQhkPfLomoYY3/view?usp=sharing) |
| ViT-B/16 | 4 | 16 x 16 | 83.7 | 96.2 | [script](scripts/train_k400_vitb16_16f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1dax4qUIOEI_QzYXv31J-87cDkonQetVQ/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1l2ivY28jUpwSmafQZvwtUo7tvm42i0PL/view?usp=sharing) |
| ViT-B/16 | 4 | 32 x 8 | 84.3 | 96.6 | [script](scripts/train_k400_vitb16_32f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1fzFM5pD39Kfp8xRAJuWaXR9RALLmnoeU/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1X1ZOdSCxXVeMpNhr_bviNKlRfJa5SMD7/view?usp=sharing) |
| ViT-L/14 | 4 | 8 x 16 | 86.3 | 97.2 | [script](scripts/train_k400_vitl14_8f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1AkdF4CkOVW2uiycCVqCxS397oYxNISAI/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1OJFBmaE_tAwTzG-4i0CLQmhwGnN0psx1/view?usp=sharing) |
| ViT-L/14 | 4 | 16 x 16 | 86.9 | 97.4 | [script](scripts/train_k400_vitl14_16f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1CTV9geLD3HLWzByAQUOf_m0F_g2lE3rg/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1a2iC4tQvjWFMI3UrEv2chuHwVrF6p9YF/view?usp=sharing) |
| ViT-L/14 | 4 | 32 x 8 | 87.7 | 97.6 | [script](scripts/train_k400_vitl14_32f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1zNFNCKwP5owakELlnTCD20cpVQBqgJrB/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1dK7qoz3McYrmfS09FfreXC-LjUM7l0u4/view?usp=sharing) |
| ViT-L/14 (336px) | 4 | 32 x 8 | 87.7 | 97.8 | | | |

## Data Loading Speed

As the training process is fast, video frames are consumed at a very high rate.
For easier installation, the current version uses PyTorch-builtin data loaders.
They are not very efficient and can become a bottleneck when using ViT-B as backbones.
We provide a `--dummy_dataset` option to bypass actual video decoding for training speed measurement. 
The model accuracy should not be affected. 
Our internal data loader is pure C++-based and does not bottleneck training by much on a machine with 2x Xeon Gold 6148 CPUs and 4x V100 GPUs.


## Acknowledgements

The data loader code is modified from [PySlowFast](https://github.com/facebookresearch/SlowFast). Thanks for their awesome work!


================================================
FILE: checkpoint.py
================================================
#!/usr/bin/env python

import argparse
import os

import torch
import torch.distributed as dist


def setup_arg_parser(parser: argparse.ArgumentParser):
    parser.add_argument('--checkpoint_dir', type=str,
                        help='checkpoint output path')
    parser.add_argument('--auto_resume', action='store_true',
                        help='auto resume from the last checkpoint from checkpoint_dir')
    parser.add_argument('--resume_path', type=str,
                        help='resume from manually specified checkpoint file, overriding auto_resume')
    parser.add_argument('--pretrain', type=str,
                        help='path to pretrained weights. will NOT override auto_resume of resume_path, '
                             'load optimizer state or enforce strict matching of checkpoint and model weights.')


def _find_autoresume_path(args: argparse.Namespace):
    print('Trying to auto resume from path:', args.checkpoint_dir)

    if os.path.isdir(args.checkpoint_dir):
        checkpoint_files = [x for x in os.listdir(args.checkpoint_dir) if x.startswith('checkpoint-') and x.endswith('.pth')]
        checkpoint_iters = []
        for x in checkpoint_files:
            try:
                x = x[len('checkpoint-'): -len('.pth')]
                x = int(x)
            except ValueError:
                continue
            checkpoint_iters.append(x)
    else:
        checkpoint_iters = []

    if len(checkpoint_iters) == 0:
        print('Did not find a valid checkpoint file.')
    else:
        checkpoint_iters.sort()
        args.resume_path = os.path.join(args.checkpoint_dir, 'checkpoint-%d.pth' % checkpoint_iters[-1])
        print(f'Found {len(checkpoint_iters)} checkpoint file(s).')


def resume_from_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_sched: torch.optim.lr_scheduler._LRScheduler,
    loss_scaler: torch.cuda.amp.grad_scaler.GradScaler,
    args: argparse.Namespace,
) -> int:
    if args.pretrain is not None:
        print(f'Loading pretrain model: {args.pretrain}')
        ckpt = torch.load(args.pretrain, map_location='cpu')
        print(model.load_state_dict(ckpt['model'], strict=False))

    # returns resume_step on successful resume, or 0 otherwise.
    if args.auto_resume and args.resume_path is None:
        _find_autoresume_path(args)
    
    if args.resume_path is None:
        print('Not resuming from a checkpoint.')
        return 0
    else:
        print(f'Resuming from checkpoint file {args.resume_path}')
        ckpt = torch.load(args.resume_path, map_location='cpu')
        model.load_state_dict(ckpt['model'], strict=True)
        if 'optimizer' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
            lr_sched.load_state_dict(ckpt['lr_sched'])
            loss_scaler.load_state_dict(ckpt['loss_scaler'])
            return ckpt['next_step']
        else:
            print('Optimizer state is NOT found in checkpoint.')
            return 0


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_sched: torch.optim.lr_scheduler._LRScheduler,
    loss_scaler: torch.cuda.amp.grad_scaler.GradScaler,
    next_step: int,
    args: argparse.Namespace,
):
    if args.checkpoint_dir is None:
        return

    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
    
    to_save = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_sched': lr_sched.state_dict(),
        'loss_scaler': loss_scaler.state_dict(),
        'next_step': next_step,
    }
    torch.save(to_save, os.path.join(args.checkpoint_dir, f'checkpoint-{next_step}.pth'))


================================================
FILE: data/k400_class_mappings.json
================================================
[
  "abseiling",
  "air drumming",
  "answering questions",
  "applauding",
  "applying cream",
  "archery",
  "arm wrestling",
  "arranging flowers",
  "assembling computer",
  "auctioning",
  "baby waking up",
  "baking cookies",
  "balloon blowing",
  "bandaging",
  "barbequing",
  "bartending",
  "beatboxing",
  "bee keeping",
  "belly dancing",
  "bench pressing",
  "bending back",
  "bending metal",
  "biking through snow",
  "blasting sand",
  "blowing glass",
  "blowing leaves",
  "blowing nose",
  "blowing out candles",
  "bobsledding",
  "bookbinding",
  "bouncing on trampoline",
  "bowling",
  "braiding hair",
  "breading or breadcrumbing",
  "breakdancing",
  "brush painting",
  "brushing hair",
  "brushing teeth",
  "building cabinet",
  "building shed",
  "bungee jumping",
  "busking",
  "canoeing or kayaking",
  "capoeira",
  "carrying baby",
  "cartwheeling",
  "carving pumpkin",
  "catching fish",
  "catching or throwing baseball",
  "catching or throwing frisbee",
  "catching or throwing softball",
  "celebrating",
  "changing oil",
  "changing wheel",
  "checking tires",
  "cheerleading",
  "chopping wood",
  "clapping",
  "clay pottery making",
  "clean and jerk",
  "cleaning floor",
  "cleaning gutters",
  "cleaning pool",
  "cleaning shoes",
  "cleaning toilet",
  "cleaning windows",
  "climbing a rope",
  "climbing ladder",
  "climbing tree",
  "contact juggling",
  "cooking chicken",
  "cooking egg",
  "cooking on campfire",
  "cooking sausages",
  "counting money",
  "country line dancing",
  "cracking neck",
  "crawling baby",
  "crossing river",
  "crying",
  "curling hair",
  "cutting nails",
  "cutting pineapple",
  "cutting watermelon",
  "dancing ballet",
  "dancing charleston",
  "dancing gangnam style",
  "dancing macarena",
  "deadlifting",
  "decorating the christmas tree",
  "digging",
  "dining",
  "disc golfing",
  "diving cliff",
  "dodgeball",
  "doing aerobics",
  "doing laundry",
  "doing nails",
  "drawing",
  "dribbling basketball",
  "drinking",
  "drinking beer",
  "drinking shots",
  "driving car",
  "driving tractor",
  "drop kicking",
  "drumming fingers",
  "dunking basketball",
  "dying hair",
  "eating burger",
  "eating cake",
  "eating carrots",
  "eating chips",
  "eating doughnuts",
  "eating hotdog",
  "eating ice cream",
  "eating spaghetti",
  "eating watermelon",
  "egg hunting",
  "exercising arm",
  "exercising with an exercise ball",
  "extinguishing fire",
  "faceplanting",
  "feeding birds",
  "feeding fish",
  "feeding goats",
  "filling eyebrows",
  "finger snapping",
  "fixing hair",
  "flipping pancake",
  "flying kite",
  "folding clothes",
  "folding napkins",
  "folding paper",
  "front raises",
  "frying vegetables",
  "garbage collecting",
  "gargling",
  "getting a haircut",
  "getting a tattoo",
  "giving or receiving award",
  "golf chipping",
  "golf driving",
  "golf putting",
  "grinding meat",
  "grooming dog",
  "grooming horse",
  "gymnastics tumbling",
  "hammer throw",
  "headbanging",
  "headbutting",
  "high jump",
  "high kick",
  "hitting baseball",
  "hockey stop",
  "holding snake",
  "hopscotch",
  "hoverboarding",
  "hugging",
  "hula hooping",
  "hurdling",
  "hurling (sport)",
  "ice climbing",
  "ice fishing",
  "ice skating",
  "ironing",
  "javelin throw",
  "jetskiing",
  "jogging",
  "juggling balls",
  "juggling fire",
  "juggling soccer ball",
  "jumping into pool",
  "jumpstyle dancing",
  "kicking field goal",
  "kicking soccer ball",
  "kissing",
  "kitesurfing",
  "knitting",
  "krumping",
  "laughing",
  "laying bricks",
  "long jump",
  "lunge",
  "making a cake",
  "making a sandwich",
  "making bed",
  "making jewelry",
  "making pizza",
  "making snowman",
  "making sushi",
  "making tea",
  "marching",
  "massaging back",
  "massaging feet",
  "massaging legs",
  "massaging person's head",
  "milking cow",
  "mopping floor",
  "motorcycling",
  "moving furniture",
  "mowing lawn",
  "news anchoring",
  "opening bottle",
  "opening present",
  "paragliding",
  "parasailing",
  "parkour",
  "passing American football (in game)",
  "passing American football (not in game)",
  "peeling apples",
  "peeling potatoes",
  "petting animal (not cat)",
  "petting cat",
  "picking fruit",
  "planting trees",
  "plastering",
  "playing accordion",
  "playing badminton",
  "playing bagpipes",
  "playing basketball",
  "playing bass guitar",
  "playing cards",
  "playing cello",
  "playing chess",
  "playing clarinet",
  "playing controller",
  "playing cricket",
  "playing cymbals",
  "playing didgeridoo",
  "playing drums",
  "playing flute",
  "playing guitar",
  "playing harmonica",
  "playing harp",
  "playing ice hockey",
  "playing keyboard",
  "playing kickball",
  "playing monopoly",
  "playing organ",
  "playing paintball",
  "playing piano",
  "playing poker",
  "playing recorder",
  "playing saxophone",
  "playing squash or racquetball",
  "playing tennis",
  "playing trombone",
  "playing trumpet",
  "playing ukulele",
  "playing violin",
  "playing volleyball",
  "playing xylophone",
  "pole vault",
  "presenting weather forecast",
  "pull ups",
  "pumping fist",
  "pumping gas",
  "punching bag",
  "punching person (boxing)",
  "push up",
  "pushing car",
  "pushing cart",
  "pushing wheelchair",
  "reading book",
  "reading newspaper",
  "recording music",
  "riding a bike",
  "riding camel",
  "riding elephant",
  "riding mechanical bull",
  "riding mountain bike",
  "riding mule",
  "riding or walking with horse",
  "riding scooter",
  "riding unicycle",
  "ripping paper",
  "robot dancing",
  "rock climbing",
  "rock scissors paper",
  "roller skating",
  "running on treadmill",
  "sailing",
  "salsa dancing",
  "sanding floor",
  "scrambling eggs",
  "scuba diving",
  "setting table",
  "shaking hands",
  "shaking head",
  "sharpening knives",
  "sharpening pencil",
  "shaving head",
  "shaving legs",
  "shearing sheep",
  "shining shoes",
  "shooting basketball",
  "shooting goal (soccer)",
  "shot put",
  "shoveling snow",
  "shredding paper",
  "shuffling cards",
  "side kick",
  "sign language interpreting",
  "singing",
  "situp",
  "skateboarding",
  "ski jumping",
  "skiing (not slalom or crosscountry)",
  "skiing crosscountry",
  "skiing slalom",
  "skipping rope",
  "skydiving",
  "slacklining",
  "slapping",
  "sled dog racing",
  "smoking",
  "smoking hookah",
  "snatch weight lifting",
  "sneezing",
  "sniffing",
  "snorkeling",
  "snowboarding",
  "snowkiting",
  "snowmobiling",
  "somersaulting",
  "spinning poi",
  "spray painting",
  "spraying",
  "springboard diving",
  "squat",
  "sticking tongue out",
  "stomping grapes",
  "stretching arm",
  "stretching leg",
  "strumming guitar",
  "surfing crowd",
  "surfing water",
  "sweeping floor",
  "swimming backstroke",
  "swimming breast stroke",
  "swimming butterfly stroke",
  "swing dancing",
  "swinging legs",
  "swinging on something",
  "sword fighting",
  "tai chi",
  "taking a shower",
  "tango dancing",
  "tap dancing",
  "tapping guitar",
  "tapping pen",
  "tasting beer",
  "tasting food",
  "testifying",
  "texting",
  "throwing axe",
  "throwing ball",
  "throwing discus",
  "tickling",
  "tobogganing",
  "tossing coin",
  "tossing salad",
  "training dog",
  "trapezing",
  "trimming or shaving beard",
  "trimming trees",
  "triple jump",
  "tying bow tie",
  "tying knot (not on a tie)",
  "tying tie",
  "unboxing",
  "unloading truck",
  "using computer",
  "using remote controller (not gaming)",
  "using segway",
  "vault",
  "waiting in line",
  "walking the dog",
  "washing dishes",
  "washing feet",
  "washing hair",
  "washing hands",
  "water skiing",
  "water sliding",
  "watering plants",
  "waxing back",
  "waxing chest",
  "waxing eyebrows",
  "waxing legs",
  "weaving basket",
  "welding",
  "whistling",
  "windsurfing",
  "wrapping present",
  "wrestling",
  "writing",
  "yawning",
  "yoga",
  "zumba"
]


================================================
FILE: main.py
================================================
#!/usr/bin/env python

import argparse
from datetime import datetime
import builtins

import torch
import torch.distributed as dist

import video_dataset
import checkpoint
from model import EVLTransformer
from video_dataset import dataloader
from weight_loaders import weight_loader_fn_dict
from vision_transformer import vit_presets

def setup_print(is_master: bool):
    """
    This function disables printing when not in master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            now = datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def main():
    parser = argparse.ArgumentParser()
    
    video_dataset.setup_arg_parser(parser)
    checkpoint.setup_arg_parser(parser)

    parser.add_argument('--num_steps', type=int,
                        help='number of training steps')
    parser.add_argument('--eval_only', action='store_true',
                        help='run evaluation only')
    parser.add_argument('--save_freq', type=int, default=5000,
                        help='save a checkpoint every N steps')
    parser.add_argument('--eval_freq', type=int, default=5000,
                        help='evaluate every N steps')
    parser.add_argument('--print_freq', type=int, default=10,
                        help='print log message every N steps')

    parser.add_argument('--backbone', type=str, choices=vit_presets.keys(), default='ViT-B/16-lnpre',
                        help='the backbone variant used to generate image feature maps')
    parser.add_argument('--backbone_path', type=str,
                        help='path to pretrained backbone weights')
    parser.add_argument('--backbone_type', type=str, default='clip', choices=weight_loader_fn_dict.keys(),
                        help='type of backbone weights (used to determine how to convert state_dict from different pretraining codebase)')
    parser.add_argument('--finetune_backbone', action='store_true',
                        help='finetune backbone weights')
    parser.add_argument('--decoder_num_layers', type=int, default=4,
                        help='number of decoder layers')
    parser.add_argument('--decoder_qkv_dim', type=int, default=768,
                        help='q (k, v) projection output dimensions in decoder attention layers')
    parser.add_argument('--decoder_num_heads', type=int, default=12,
                        help='number of heads in decoder attention layers')
    parser.add_argument('--decoder_mlp_factor', type=float, default=4.0,
                        help='expansion factor of feature dimension in the middle of decoder MLPs')
    parser.add_argument('--num_classes', type=int, default=400,
                        help='number of classes')
    parser.add_argument('--cls_dropout', type=float, default=0.5,
                        help='dropout rate applied before the final classification linear projection')
    parser.add_argument('--decoder_mlp_dropout', type=float, default=0.5,
                        help='dropout rate applied in MLP layers in the decoder')
    parser.add_argument('--no_temporal_conv', action='store_false', dest='temporal_conv',
                        help='disable temporal convolution on frame features')
    parser.add_argument('--no_temporal_pos_embed', action='store_false', dest='temporal_pos_embed',
                        help='disable temporal position embeddings added to frame features')
    parser.add_argument('--no_temporal_cross_attention', action='store_false', dest='temporal_cross_attention',
                        help='disable temporal cross attention on frame query and key features')
    parser.set_defaults(temporal_conv=True, temporal_pos_embed=True, temporal_cross_attention=True)

    parser.add_argument('--lr', type=float, default=4e-4,
                        help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='optimizer weight decay')
    parser.add_argument('--disable_fp16', action='store_false', dest='fp16',
                        help='disable fp16 during training or inference')
    parser.set_defaults(fp16=True)

    parser.add_argument('--batch_split', type=int, default=1,
                        help='optionally split the batch into smaller shards and forward/backward one shard '
                             'at a time to avoid out-of-memory error.')

    args = parser.parse_args()

    dist.init_process_group('nccl')
    setup_print(dist.get_rank() == 0)
    cuda_device_id = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(cuda_device_id)

    model = EVLTransformer(
        backbone_name=args.backbone,
        backbone_type=args.backbone_type,
        backbone_path=args.backbone_path,
        backbone_mode='finetune' if args.finetune_backbone else ('freeze_fp16' if args.fp16 else 'freeze_fp32'),
        decoder_num_layers=args.decoder_num_layers,
        decoder_qkv_dim=args.decoder_qkv_dim,
        decoder_num_heads=args.decoder_num_heads,
        decoder_mlp_factor=args.decoder_mlp_factor,
        num_classes=args.num_classes,
        enable_temporal_conv=args.temporal_conv,
        enable_temporal_pos_embed=args.temporal_pos_embed,
        enable_temporal_cross_attention=args.temporal_cross_attention,
        cls_dropout=args.cls_dropout,
        decoder_mlp_dropout=args.decoder_mlp_dropout,
        num_frames=args.num_frames,
    )
    print(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[cuda_device_id], output_device=cuda_device_id,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)
    loss_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=args.fp16)
    criterion = torch.nn.CrossEntropyLoss()

    resume_step = checkpoint.resume_from_checkpoint(model, optimizer, lr_sched, loss_scaler, args)

    val_loader = video_dataset.create_val_loader(args)
    if args.eval_only:
        print('Running in eval_only mode.')
        model.eval()
        evaluate(model, val_loader)
        return
    else:
        assert args.train_list_path is not None, 'Train list path must be specified if not in eval_only mode.'
        train_loader = video_dataset.create_train_loader(args, resume_step=resume_step)

    assert len(train_loader) == args.num_steps - resume_step
    batch_st, train_st = datetime.now(), datetime.now()
    for i, (data, labels) in enumerate(train_loader, resume_step):
        data, labels = data.cuda(), labels.cuda()
        data_ed = datetime.now()

        optimizer.zero_grad()

        assert data.size(0) % args.batch_split == 0
        split_size = data.size(0) // args.batch_split
        hit1, hit5, loss_value = 0, 0, 0
        for j in range(args.batch_split):
            data_slice = data[split_size * j: split_size * (j + 1)]
            labels_slice = labels[split_size * j: split_size * (j + 1)]

            with torch.cuda.amp.autocast(args.fp16):
                logits = model(data_slice)
                loss = criterion(logits, labels_slice)
                
            if labels.dtype == torch.long: # no mixup, can calculate accuracy
                hit1 += (logits.topk(1, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
                hit5 += (logits.topk(5, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
            loss_value += loss.item() / args.batch_split
            
            loss_scaler.scale(loss / args.batch_split).backward()
        
        loss_scaler.step(optimizer)
        loss_scaler.update()
        lr_sched.step()

        batch_ed = datetime.now()

        if i % args.print_freq == 0:
            sync_tensor = torch.Tensor([loss_value, hit1 / data.size(0), hit5 / data.size(0)]).cuda()
            dist.all_reduce(sync_tensor)
            sync_tensor = sync_tensor.cpu() / dist.get_world_size()
            loss_value, acc1, acc5 = sync_tensor.tolist()

            print(
                f'batch_time: {(batch_ed - batch_st).total_seconds():.3f}  '
                f'data_time: {(data_ed - batch_st).total_seconds():.3f}  '
                f'ETA: {(batch_ed - train_st) / (i - resume_step + 1) * (args.num_steps - i - 1)}  |  '
                f'lr: {optimizer.param_groups[0]["lr"]:.6f}  '
                f'loss: {loss_value:.6f}' + (
                    f'  acc1: {acc1 * 100:.2f}%  acc5: {acc5 * 100:.2f}%' if labels.dtype == torch.long else ''
                )
            )
        
        if (i + 1) % args.eval_freq == 0:
            print('Start model evaluation at step', i + 1)
            model.eval()
            evaluate(model, val_loader)
            model.train()

        if (i + 1) % args.save_freq == 0:
            checkpoint.save_checkpoint(model, optimizer, lr_sched, loss_scaler, i + 1, args)
        
        batch_st = datetime.now()


def evaluate(model: torch.nn.Module, loader: torch.utils.data.DataLoader):
    tot, hit1, hit5 = 0, 0, 0
    eval_st = datetime.now()
    for data, labels in loader:
        data, labels = data.cuda(), labels.cuda()
        assert data.size(0) == 1
        if data.ndim == 6:
            data = data[0] # now the first dimension is number of views

        with torch.no_grad():
            logits = model(data)
            scores = logits.softmax(dim=-1).mean(dim=0)

        tot += 1
        hit1 += (scores.topk(1)[1] == labels).sum().item()
        hit5 += (scores.topk(5)[1] == labels).sum().item()

        if tot % 20 == 0:
            print(f'[Evaluation] num_samples: {tot}  '
                  f'ETA: {(datetime.now() - eval_st) / tot * (len(loader) - tot)}  '
                  f'cumulative_acc1: {hit1 / tot * 100.:.2f}%  '
                  f'cumulative_acc5: {hit5 / tot * 100.:.2f}%')

    sync_tensor = torch.LongTensor([tot, hit1, hit5]).cuda()
    dist.all_reduce(sync_tensor)
    tot, hit1, hit5 = sync_tensor.cpu().tolist()

    print(f'Accuracy on validation set: top1={hit1 / tot * 100:.2f}%, top5={hit5 / tot * 100:.2f}%')


if __name__ == '__main__': main()


================================================
FILE: model.py
================================================
#!/usr/bin/env python

from typing import Dict, Iterable, List, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from vision_transformer import QuickGELU, Attention
from weight_loaders import weight_loader_fn_dict
from vision_transformer import (
    VisionTransformer2D, TransformerDecoderLayer,
    model_to_fp16, vit_presets,
)
        

class TemporalCrossAttention(nn.Module):

    def __init__(
        self,
        spatial_size: Tuple[int, int] = (14, 14),
        feature_dim: int = 768,
    ):
        super().__init__()

        self.spatial_size = spatial_size

        w_size = np.prod([x * 2 - 1 for x in spatial_size])
        self.w1 = nn.Parameter(torch.zeros([w_size, feature_dim]))
        self.w2 = nn.Parameter(torch.zeros([w_size, feature_dim]))

        idx_tensor = torch.zeros([np.prod(spatial_size) for _ in (0, 1)], dtype=torch.long)
        for q in range(np.prod(spatial_size)):
            qi, qj = q // spatial_size[1], q % spatial_size[1]
            for k in range(np.prod(spatial_size)):
                ki, kj = k // spatial_size[1], k % spatial_size[1]
                i_offs = qi - ki + spatial_size[0] - 1
                j_offs = qj - kj + spatial_size[1] - 1
                idx_tensor[q, k] = i_offs * (spatial_size[1] * 2 - 1) + j_offs
        self.idx_tensor = idx_tensor


    def forward_half(self, q: torch.Tensor, k: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        q, k = q[:, :, 1:], k[:, :, 1:] # remove cls token

        assert q.size() == k.size()
        assert q.size(2) == np.prod(self.spatial_size)

        attn = torch.einsum('ntqhd,ntkhd->ntqkh', q / (q.size(-1) ** 0.5), k)
        attn = attn.softmax(dim=-2).mean(dim=-1) # L, L, N, T

        self.idx_tensor = self.idx_tensor.to(w.device)
        w_unroll = w[self.idx_tensor] # L, L, C
        ret = torch.einsum('ntqk,qkc->ntqc', attn, w_unroll)

        return ret


    def forward(self, q: torch.Tensor, k: torch.Tensor):
        N, T, L, H, D = q.size()
        assert L == np.prod(self.spatial_size) + 1

        ret = torch.zeros([N, T, L, self.w1.size(-1)], device='cuda')
        ret[:, 1:, 1:, :] += self.forward_half(q[:, 1:, :, :, :], k[:, :-1, :, :, :], self.w1)
        ret[:, :-1, 1:, :] += self.forward_half(q[:, :-1, :, :, :], k[:, 1:, :, :, :], self.w2)

        return ret


class EVLDecoder(nn.Module):

    def __init__(
        self,
        num_frames: int = 8,
        spatial_size: Tuple[int, int] = (14, 14),
        num_layers: int = 4,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        enable_temporal_conv: bool = True,
        enable_temporal_pos_embed: bool = True,
        enable_temporal_cross_attention: bool = True,
        mlp_dropout: float = 0.5,
    ):
        super().__init__()

        self.enable_temporal_conv = enable_temporal_conv
        self.enable_temporal_pos_embed = enable_temporal_pos_embed
        self.enable_temporal_cross_attention = enable_temporal_cross_attention
        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
            [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, mlp_factor, mlp_dropout) for _ in range(num_layers)]
        )

        if enable_temporal_conv:
            self.temporal_conv = nn.ModuleList(
                [nn.Conv1d(in_feature_dim, in_feature_dim, kernel_size=3, stride=1, padding=1, groups=in_feature_dim) for _ in range(num_layers)]
            )
        if enable_temporal_pos_embed:
            self.temporal_pos_embed = nn.ParameterList(
                [nn.Parameter(torch.zeros([num_frames, in_feature_dim])) for _ in range(num_layers)]
            )
        if enable_temporal_cross_attention:
            self.cross_attention = nn.ModuleList(
                [TemporalCrossAttention(spatial_size, in_feature_dim) for _ in range(num_layers)]
            )

        self.cls_token = nn.Parameter(torch.zeros([in_feature_dim]))


    def _initialize_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)


    def forward(self, in_features: List[Dict[str, torch.Tensor]]):
        N, T, L, C = in_features[0]['out'].size()
        assert len(in_features) == self.num_layers
        x = self.cls_token.view(1, 1, -1).repeat(N, 1, 1)

        for i in range(self.num_layers):
            frame_features = in_features[i]['out']
            
            if self.enable_temporal_conv:
                feat = in_features[i]['out']
                feat = feat.permute(0, 2, 3, 1).contiguous().flatten(0, 1) # N * L, C, T
                feat = self.temporal_conv[i](feat)
                feat = feat.view(N, L, C, T).permute(0, 3, 1, 2).contiguous() # N, T, L, C
                frame_features += feat
            
            if self.enable_temporal_pos_embed:
                frame_features += self.temporal_pos_embed[i].view(1, T, 1, C)
            
            if self.enable_temporal_cross_attention:
                frame_features += self.cross_attention[i](in_features[i]['q'], in_features[i]['k'])

            frame_features = frame_features.flatten(1, 2) # N, T * L, C
            
            x = self.decoder_layers[i](x, frame_features)
        
        return x


class EVLTransformer(nn.Module):

    def __init__(
        self,
        num_frames: int = 8,
        backbone_name: str = 'ViT-B/16',
        backbone_type: str = 'clip',
        backbone_path: str = '',
        backbone_mode: str = 'frozen_fp16',
        decoder_num_layers: int = 4,
        decoder_qkv_dim: int = 768,
        decoder_num_heads: int = 12,
        decoder_mlp_factor: float = 4.0,
        num_classes: int = 400,
        enable_temporal_conv: bool = True,
        enable_temporal_pos_embed: bool = True,
        enable_temporal_cross_attention: bool = True,
        cls_dropout: float = 0.5,
        decoder_mlp_dropout: float = 0.5,
    ):
        super().__init__()

        self.decoder_num_layers = decoder_num_layers

        backbone_config = self._create_backbone(backbone_name, backbone_type, backbone_path, backbone_mode)
        backbone_feature_dim = backbone_config['feature_dim']
        backbone_spatial_size = tuple(x // y for x, y in zip(backbone_config['input_size'], backbone_config['patch_size']))

        self.decoder = EVLDecoder(
            num_frames=num_frames,
            spatial_size=backbone_spatial_size,
            num_layers=decoder_num_layers,
            in_feature_dim=backbone_feature_dim,
            qkv_dim=decoder_qkv_dim,
            num_heads=decoder_num_heads,
            mlp_factor=decoder_mlp_factor,
            enable_temporal_conv=enable_temporal_conv,
            enable_temporal_pos_embed=enable_temporal_pos_embed,
            enable_temporal_cross_attention=enable_temporal_cross_attention,
            mlp_dropout=decoder_mlp_dropout,
        )
        self.proj = nn.Sequential(
            nn.LayerNorm(backbone_feature_dim),
            nn.Dropout(cls_dropout),
            nn.Linear(backbone_feature_dim, num_classes),
        )


    def _create_backbone(
        self,
        backbone_name: str,
        backbone_type: str,
        backbone_path: str,
        backbone_mode: str,
    ) -> dict:
        weight_loader_fn = weight_loader_fn_dict[backbone_type]
        state_dict = weight_loader_fn(backbone_path)

        backbone = VisionTransformer2D(return_all_features=True, **vit_presets[backbone_name])
        backbone.load_state_dict(state_dict, strict=True) # weight_loader_fn is expected to strip unused parameters

        assert backbone_mode in ['finetune', 'freeze_fp16', 'freeze_fp32']

        if backbone_mode == 'finetune':
            self.backbone = backbone
        else:
            backbone.eval().requires_grad_(False)
            if backbone_mode == 'freeze_fp16':
                model_to_fp16(backbone)
            self.backbone = [backbone] # avoid backbone parameter registration

        return vit_presets[backbone_name]


    def _get_backbone(self, x):
        if isinstance(self.backbone, list):
            # freeze backbone
            self.backbone[0] = self.backbone[0].to(x.device)
            return self.backbone[0]
        else:
            # finetune bakbone
            return self.backbone


    def forward(self, x: torch.Tensor):
        backbone = self._get_backbone(x)

        B, C, T, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
        features = backbone(x)[-self.decoder_num_layers:]
        features = [
            dict((k, v.float().view(B, T, *v.size()[1:])) for k, v in x.items())
            for x in features
        ]

        x = self.decoder(features)
        x = self.proj(x[:, 0, :])

        return x

================================================
FILE: scripts/eval_k400_vitb16_16f_dec4x768.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-B/16-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-B-16.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 768 \
    --decoder_num_heads 12 \
    --num_classes 400 \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 16 \
    --sampling_rate 16 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
    --resume_path /path/to/checkpoint_release/k400_vitb16_16f_dec4x768.pth \
    --eval_only


================================================
FILE: scripts/eval_k400_vitb16_32f_dec4x768.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-B/16-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-B-16.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 768 \
    --decoder_num_heads 12 \
    --num_classes 400 \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 32 \
    --sampling_rate 8 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
    --resume_path /path/to/checkpoint_release/k400_vitb16_32f_dec4x768.pth \
    --eval_only


================================================
FILE: scripts/eval_k400_vitb16_8f_dec4x768.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-B/16-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-B-16.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 768 \
    --decoder_num_heads 12 \
    --num_classes 400 \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 8 \
    --sampling_rate 16 \
    --num_spatial_views 1 \
    --num_temporal_views 3 \
    --resume_path /path/to/checkpoint_release/k400_vitb16_8f_dec4x768.pth \
    --eval_only


================================================
FILE: scripts/eval_k400_vitl14_16f_dec4x1024.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-L/14-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-L-14.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 1024 \
    --decoder_num_heads 16 \
    --num_classes 400 \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 16 \
    --sampling_rate 16 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
    --resume_path /path/to/checkpoint_release/k400_vitl14_16f_dec4x1024.pth \
    --eval_only


================================================
FILE: scripts/eval_k400_vitl14_32f_dec4x1024.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-L/14-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-L-14.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 1024 \
    --decoder_num_heads 16 \
    --num_classes 400 \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 32 \
    --sampling_rate 8 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
    --resume_path /path/to/checkpoint_release/k400_vitl14_32f_dec4x1024.pth \
    --eval_only


================================================
FILE: scripts/eval_k400_vitl14_8f_dec4x1024.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-L/14-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-L-14.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 1024 \
    --decoder_num_heads 16 \
    --num_classes 400 \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 8 \
    --sampling_rate 16 \
    --num_spatial_views 1 \
    --num_temporal_views 3 \
    --resume_path /path/to/checkpoint_release/k400_vitl14_8f_dec4x1024.pth \
    --eval_only


================================================
FILE: scripts/train_k400_vitb16_16f_dec4x768.sh
================================================
#!/usr/bin/env sh

exp_dir=runs/k400_vitb16_16f_dec4x768

mkdir -p "${exp_dir}"
python -u -m torch.distributed.run --nproc_per_node 8 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-B/16-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-B-16.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 768 \
    --decoder_num_heads 12 \
    --num_classes 400 \
    --checkpoint_dir "${exp_dir}" \
    --auto_resume \
    --train_list_path /path/to/k400/train.txt \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 16 \
    --sampling_rate 16 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
  2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"


================================================
FILE: scripts/train_k400_vitb16_32f_dec4x768.sh
================================================
#!/usr/bin/env sh

exp_dir=runs/k400_vitb16_32f_dec4x768

mkdir -p "${exp_dir}"
python -u -m torch.distributed.run --nproc_per_node 8 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-B/16-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-B-16.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 768 \
    --decoder_num_heads 12 \
    --num_classes 400 \
    --checkpoint_dir "${exp_dir}" \
    --auto_resume \
    --train_list_path /path/to/k400/train.txt \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 32 \
    --sampling_rate 8 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
  2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"


================================================
FILE: scripts/train_k400_vitb16_8f_dec4x768.sh
================================================
#!/usr/bin/env sh

exp_dir=runs/k400_vitb16_8f_dec4x768

mkdir -p "${exp_dir}"
python -u -m torch.distributed.run --nproc_per_node 8 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-B/16-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-B-16.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 768 \
    --decoder_num_heads 12 \
    --num_classes 400 \
    --checkpoint_dir "${exp_dir}" \
    --auto_resume \
    --train_list_path /path/to/k400/train.txt \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 8 \
    --sampling_rate 16 \
    --num_spatial_views 1 \
    --num_temporal_views 3 \
  2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"


================================================
FILE: scripts/train_k400_vitl14_16f_dec4x1024.sh
================================================
#!/usr/bin/env sh

exp_dir=runs/k400_vitl14_16f_dec4x1024

mkdir -p "${exp_dir}"
python -u -m torch.distributed.run --nproc_per_node 8 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-L/14-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-L-14.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 1024 \
    --decoder_num_heads 16 \
    --num_classes 400 \
    --checkpoint_dir "${exp_dir}" \
    --auto_resume \
    --train_list_path /path/to/k400/train.txt \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 2 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 16 \
    --sampling_rate 16 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
  2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"


================================================
FILE: scripts/train_k400_vitl14_32f_dec4x1024.sh
================================================
#!/usr/bin/env sh

exp_dir=runs/k400_vitl14_32f_dec4x1024

mkdir -p "${exp_dir}"
python -u -m torch.distributed.run --nproc_per_node 8 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-L/14-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-L-14.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 1024 \
    --decoder_num_heads 16 \
    --num_classes 400 \
    --checkpoint_dir "${exp_dir}" \
    --auto_resume \
    --train_list_path /path/to/k400/train.txt \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 4 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 32 \
    --sampling_rate 8 \
    --num_spatial_views 3 \
    --num_temporal_views 1 \
  2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"


================================================
FILE: scripts/train_k400_vitl14_8f_dec4x1024.sh
================================================
#!/usr/bin/env sh

exp_dir=runs/k400_vitl14_8f_dec4x1024

mkdir -p "${exp_dir}"
python -u -m torch.distributed.run --nproc_per_node 8 \
  main.py \
    --num_steps 50000 \
    --backbone "ViT-L/14-lnpre" \
    --backbone_type clip \
    --backbone_path /path/to/clip_models/ViT-L-14.pt \
    --decoder_num_layers 4 \
    --decoder_qkv_dim 1024 \
    --decoder_num_heads 16 \
    --num_classes 400 \
    --checkpoint_dir "${exp_dir}" \
    --auto_resume \
    --train_list_path /path/to/k400/train.txt \
    --val_list_path /path/to/k400/val.txt \
    --batch_size 256 \
    --batch_split 1 \
    --auto_augment rand-m7-n4-mstd0.5-inc1 \
    --mean 0.48145466 0.4578275 0.40821073 \
    --std 0.26862954 0.26130258 0.27577711 \
    --num_workers 12 \
    --num_frames 8 \
    --sampling_rate 16 \
    --num_spatial_views 1 \
    --num_temporal_views 3 \
  2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log"


================================================
FILE: video_dataset/__init__.py
================================================
#!/usr/bin/env python

from .dataloader import setup_arg_parser, create_train_loader, create_val_loader

================================================
FILE: video_dataset/dataloader.py
================================================
#!/usr/bin/env python

import argparse
from typing import Dict

import torch
import torch.distributed as dist

from .dataset import VideoDataset, DummyDataset

def setup_arg_parser(parser: argparse.ArgumentParser):
    parser.add_argument('--train_list_path', type=str,
                        help='path to training data list')
    parser.add_argument('--val_list_path', type=str,
                        help='path to validation data list')
    parser.add_argument('--train_data_root', type=str,
                        help='training samples root directory')
    parser.add_argument('--val_data_root', type=str,
                        help='validation samples root directory')
    parser.add_argument('--data_root', type=str, default='',
                        help='training and validation samples root directory, might be overrided by --train_data_root or --val_data_root')

    parser.add_argument('--batch_size', type=int,
                        help='training batch size on a all GPUs')

    parser.add_argument('--num_spatial_views', type=int, default=1,
                        help='number of spatial crops used for testing (total views = num_spatial_views * num_temporal_views)')
    parser.add_argument('--num_temporal_views', type=int, default=3,
                        help='number of temporal crops used for testing (total views = num_spatial_views * num_temporal_views)')
    parser.add_argument('--num_frames', type=int, default=8,
                        help='number of frames used for each view')
    parser.add_argument('--sampling_rate', type=int, default=16,
                        help='temporal stride for frame sampling, only valid when tsn_sampling is not enabled')
    parser.add_argument('--tsn_sampling', action='store_true',
                        help='enable TSN-style sampling (i.e. sample frames with dynamic stride to cover the whole video)')
    parser.add_argument('--spatial_size', type=int, default=224,
                        help='frame height and width in pixels')

    parser.add_argument('--mean', type=float, nargs='+',
                        help='pixel mean used to normalize the image.')
    parser.add_argument('--std', type=float, nargs='+',
                        help='pixel std used to normalize the image')

    parser.add_argument('--num_workers', type=int, default=10,
                        help='number of DataLoader worker threads')
    
    parser.add_argument('--dummy_dataset', action='store_true',
                        help='use fake datasets that generate all 0 (use for speed test only)')

    parser.add_argument('--auto_augment', type=str,
                        help='auto augment configuration')
    parser.add_argument('--interpolation', type=str, default='bicubic',
                        help='interpolation mode')
    parser.add_argument('--no_mirror', action='store_false', dest='mirror',
                        help='disable mirror for training (frequently used for the something-something dataset)')
    parser.set_defaults(mirror=True)
                        

def _parse_mean_and_std(args: argparse.Namespace) -> Dict[str, torch.Tensor]:
    def parse_mean_or_std(arg, default_value):
        if arg is None:
            return torch.Tensor([default_value] * 3)
        elif len(arg) == 1:
            return torch.Tensor(arg * 3)
        elif len(arg) == 3:
            return torch.Tensor(arg)
        else:
            raise NotImplementedError()
    return {
        'mean': parse_mean_or_std(args.mean, 0.45),
        'std': parse_mean_or_std(args.std, 0.225),
    }


def create_train_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:
    if args.dummy_dataset:
        return DummyDataset(
            list_path=args.train_list_path,
            num_frames=args.num_frames,
            num_views=1,
            spatial_size=args.spatial_size,
        )

    return VideoDataset(
        list_path=args.train_list_path,
        data_root=args.train_data_root or args.data_root,
        num_spatial_views=1, num_temporal_views=1, random_sample=True,
        auto_augment=args.auto_augment,
        interpolation=args.interpolation,
        mirror=args.mirror,
        num_frames=args.num_frames,
        sampling_rate=-1 if args.tsn_sampling else args.sampling_rate,
        spatial_size=args.spatial_size,
        **_parse_mean_and_std(args),
    )


def create_train_loader(args: argparse.Namespace, resume_step: int = 0) -> torch.utils.data.DataLoader:
    dataset = create_train_dataset(args)
    rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size())

    assert args.batch_size % world_size == 0
    batch_size_per_gpu = args.batch_size // world_size

    # manually create a step-based sampler
    sampler = []
    while len(sampler) * len(dataset) < args.num_steps * args.batch_size:
        g = torch.Generator()
        g.manual_seed(len(sampler))
        indices = torch.randperm(len(dataset), generator=g)
        sampler.append(indices)
    sampler = torch.cat(sampler, dim=0)[:args.num_steps * args.batch_size].view(args.num_steps, args.batch_size)
    sampler = sampler[resume_step:, batch_size_per_gpu * rank: batch_size_per_gpu * (rank + 1)].flatten().tolist()

    loader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, batch_size=batch_size_per_gpu,
        num_workers=args.num_workers, pin_memory=False, drop_last=True,
    )

    return loader


def create_val_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:
    if args.dummy_dataset:
        return DummyDataset(
            list_path=args.val_list_path,
            num_frames=args.num_frames,
            num_views=args.num_spatial_views * args.num_temporal_views,
            spatial_size=args.spatial_size,
        )

    return VideoDataset(
        list_path=args.val_list_path,
        data_root=args.val_data_root or args.data_root,
        num_spatial_views=args.num_spatial_views,
        num_temporal_views=args.num_temporal_views,
        random_sample=False,
        num_frames=args.num_frames,
        sampling_rate=-1 if args.tsn_sampling else args.sampling_rate,
        spatial_size=args.spatial_size,
        **_parse_mean_and_std(args),
    )


def create_val_loader(args: argparse.Namespace) -> torch.utils.data.Dataset:
    dataset = create_val_dataset(args)
    rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size())

    # sampler for distribued eval
    sampler = list(range(rank, len(dataset), world_size))

    loader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, batch_size=1,
        num_workers=args.num_workers, pin_memory=False,
    )

    return loader


================================================
FILE: video_dataset/dataset.py
================================================
#!/usr/bin/env python

import os, sys
from typing import Optional
import av
import io
import numpy as np

import torch
from torchvision import transforms

from .transform import create_random_augment, random_resized_crop

class VideoDataset(torch.utils.data.Dataset):

    def __init__(
        self, list_path: str, data_root: str,
        num_spatial_views: int, num_temporal_views: int, random_sample: bool,
        num_frames: int, sampling_rate: int, spatial_size: int,
        mean: torch.Tensor, std: torch.Tensor,
        auto_augment: Optional[str] = None, interpolation: str = 'bicubic',
        mirror: bool = False,
    ):
        self.data_root = data_root
        self.interpolation = interpolation
        self.spatial_size = spatial_size

        self.mean, self.std = mean, std
        self.num_frames, self.sampling_rate = num_frames, sampling_rate

        if random_sample:
            assert num_spatial_views == 1 and num_temporal_views == 1
            self.random_sample = True
            self.mirror = mirror
            self.auto_augment = auto_augment
        else:
            assert auto_augment is None and not mirror
            self.random_sample = False
            self.num_temporal_views = num_temporal_views
            self.num_spatial_views = num_spatial_views

        with open(list_path) as f:
            self.data_list = f.read().splitlines()


    def __len__(self):
        return len(self.data_list)
    

    def __getitem__(self, idx):
        line = self.data_list[idx]
        path, label = line.split(' ')
        path = os.path.join(self.data_root, path)
        label = int(label)

        container = av.open(path)
        frames = {}
        for frame in container.decode(video=0):
            frames[frame.pts] = frame
        container.close()
        frames = [frames[k] for k in sorted(frames.keys())]

        if self.random_sample:
            frame_idx = self._random_sample_frame_idx(len(frames))
            frames = [frames[x].to_rgb().to_ndarray() for x in frame_idx]
            frames = torch.as_tensor(np.stack(frames)).float() / 255.

            if self.auto_augment is not None:
                aug_transform = create_random_augment(
                    input_size=(frames.size(1), frames.size(2)),
                    auto_augment=self.auto_augment,
                    interpolation=self.interpolation,
                )
                frames = frames.permute(0, 3, 1, 2) # T, C, H, W
                frames = [transforms.ToPILImage()(frames[i]) for i in range(frames.size(0))]
                frames = aug_transform(frames)
                frames = torch.stack([transforms.ToTensor()(img) for img in frames])
                frames = frames.permute(0, 2, 3, 1)

            frames = (frames - self.mean) / self.std
            frames = frames.permute(3, 0, 1, 2) # C, T, H, W
            frames = random_resized_crop(
                frames, self.spatial_size, self.spatial_size,
            )
            
        else:
            frames = [x.to_rgb().to_ndarray() for x in frames]
            frames = torch.as_tensor(np.stack(frames))
            frames = frames.float() / 255.

            frames = (frames - self.mean) / self.std
            frames = frames.permute(3, 0, 1, 2) # C, T, H, W
            
            if frames.size(-2) < frames.size(-1):
                new_width = frames.size(-1) * self.spatial_size // frames.size(-2)
                new_height = self.spatial_size
            else:
                new_height = frames.size(-2) * self.spatial_size // frames.size(-1)
                new_width = self.spatial_size
            frames = torch.nn.functional.interpolate(
                frames, size=(new_height, new_width),
                mode='bilinear', align_corners=False,
            )

            frames = self._generate_spatial_crops(frames)
            frames = sum([self._generate_temporal_crops(x) for x in frames], [])
            if len(frames) > 1:
                frames = torch.stack(frames)

        return frames, label


    def _generate_temporal_crops(self, frames):
        seg_len = (self.num_frames - 1) * self.sampling_rate + 1
        if frames.size(1) < seg_len:
            frames = torch.cat([frames, frames[:, -1:].repeat(1, seg_len - frames.size(1), 1, 1)], dim=1)
        slide_len = frames.size(1) - seg_len

        crops = []
        for i in range(self.num_temporal_views):
            if self.num_temporal_views == 1:
                st = slide_len // 2
            else:
                st = round(slide_len / (self.num_temporal_views - 1) * i)

            crops.append(frames[:, st: st + self.num_frames * self.sampling_rate: self.sampling_rate])
        
        return crops


    def _generate_spatial_crops(self, frames):
        if self.num_spatial_views == 1:
            assert min(frames.size(-2), frames.size(-1)) >= self.spatial_size
            h_st = (frames.size(-2) - self.spatial_size) // 2
            w_st = (frames.size(-1) - self.spatial_size) // 2
            h_ed, w_ed = h_st + self.spatial_size, w_st + self.spatial_size
            return [frames[:, :, h_st: h_ed, w_st: w_ed]]

        elif self.num_spatial_views == 3:
            assert min(frames.size(-2), frames.size(-1)) == self.spatial_size
            crops = []
            margin = max(frames.size(-2), frames.size(-1)) - self.spatial_size
            for st in (0, margin // 2, margin):
                ed = st + self.spatial_size
                if frames.size(-2) > frames.size(-1):
                    crops.append(frames[:, :, st: ed, :])
                else:
                    crops.append(frames[:, :, :, st: ed])
            return crops
        
        else:
            raise NotImplementedError()


    def _random_sample_frame_idx(self, len):
        frame_indices = []

        if self.sampling_rate < 0: # tsn sample
            seg_size = (len - 1) / self.num_frames
            for i in range(self.num_frames):
                start, end = round(seg_size * i), round(seg_size * (i + 1))
                frame_indices.append(np.random.randint(start, end + 1))
        elif self.sampling_rate * (self.num_frames - 1) + 1 >= len:
            for i in range(self.num_frames):
                frame_indices.append(i * self.sampling_rate if i * self.sampling_rate < len else frame_indices[-1])
        else:
            start = np.random.randint(len - self.sampling_rate * (self.num_frames - 1))
            frame_indices = list(range(start, start + self.sampling_rate * self.num_frames, self.sampling_rate))

        return frame_indices


class DummyDataset(torch.utils.data.Dataset):

    def __init__(self, list_path: str, num_frames: int, num_views: int, spatial_size: int):
        with open(list_path) as f:
            self.len = len(f.read().splitlines())
        self.num_frames = num_frames
        self.num_views = num_views
        self.spatial_size = spatial_size

    def __len__(self):
        return self.len

    def __getitem__(self, _):
        shape = [3, self.num_frames, self.spatial_size, self.spatial_size]
        if self.num_views != 1:
            shape = [self.num_views] + shape
        return torch.zeros(shape), 0


================================================
FILE: video_dataset/rand_augment.py
================================================
#!/usr/bin/env python
# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/rand_augment.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""
This implementation is based on
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
pulished under an Apache License 2.0.

COMMENT FROM ORIGINAL:
AutoAugment, RandAugment, and AugMix for PyTorch
This code implements the searched ImageNet policies with various tweaks and
improvements and does not include any of the search code. AA and RA
Implementation adapted from:
    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
AugMix adapted from:
    https://github.com/google-research/augmix
Papers:
    AutoAugment: Learning Augmentation Policies from Data
    https://arxiv.org/abs/1805.09501
    Learning Data Augmentation Strategies for Object Detection
    https://arxiv.org/abs/1906.11172
    RandAugment: Practical automated data augmentation...
    https://arxiv.org/abs/1909.13719
    AugMix: A Simple Data Processing Method to Improve Robustness and
    Uncertainty https://arxiv.org/abs/1912.02781

Hacked together by / Copyright 2020 Ross Wightman
"""

import math
import numpy as np
import random
import re
import PIL
from PIL import Image, ImageEnhance, ImageOps

_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])

_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.0

_HPARAMS_DEFAULT = {
    "translate_const": 250,
    "img_mean": _FILL,
}

_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _interpolation(kwargs):
    interpolation = kwargs.pop("resample", Image.BILINEAR)
    if isinstance(interpolation, (list, tuple)):
        return random.choice(interpolation)
    else:
        return interpolation


def _check_args_tf(kwargs):
    if "fillcolor" in kwargs and _PIL_VER < (5, 0):
        kwargs.pop("fillcolor")
    kwargs["resample"] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(
        img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs
    )


def shear_y(img, factor, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(
        img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs
    )


def translate_x_rel(img, pct, **kwargs):
    pixels = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(
        img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
    )


def translate_y_rel(img, pct, **kwargs):
    pixels = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(
        img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
    )


def translate_x_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(
        img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs
    )


def translate_y_abs(img, pixels, **kwargs):
    _check_args_tf(kwargs)
    return img.transform(
        img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs
    )


def rotate(img, degrees, **kwargs):
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        angle = -math.radians(degrees)
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-math.sin(angle), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0],
            -rotn_center[1] - post_trans[1],
            matrix,
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        return img.rotate(degrees, resample=kwargs["resample"])


def auto_contrast(img, **__):
    return ImageOps.autocontrast(img)


def invert(img, **__):
    return ImageOps.invert(img)


def equalize(img, **__):
    return ImageOps.equalize(img)


def solarize(img, thresh, **__):
    return ImageOps.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    lut = []
    for i in range(256):
        if i < thresh:
            lut.append(min(255, i + add))
        else:
            lut.append(i)
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB" and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
    else:
        return img


def posterize(img, bits_to_keep, **__):
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
    return ImageEnhance.Contrast(img).enhance(factor)


def color(img, factor, **__):
    return ImageEnhance.Color(img).enhance(factor)


def brightness(img, factor, **__):
    return ImageEnhance.Brightness(img).enhance(factor)


def sharpness(img, factor, **__):
    return ImageEnhance.Sharpness(img).enhance(factor)


def _randomly_negate(v):
    """With 50% prob, negate the value"""
    return -v if random.random() > 0.5 else v


def _rotate_level_to_arg(level, _hparams):
    # range [-30, 30]
    level = (level / _MAX_LEVEL) * 30.0
    level = _randomly_negate(level)
    return (level,)


def _enhance_level_to_arg(level, _hparams):
    # range [0.1, 1.9]
    return ((level / _MAX_LEVEL) * 1.8 + 0.1,)


def _enhance_increasing_level_to_arg(level, _hparams):
    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
    # range [0.1, 1.9]
    level = (level / _MAX_LEVEL) * 0.9
    level = 1.0 + _randomly_negate(level)
    return (level,)


def _shear_level_to_arg(level, _hparams):
    # range [-0.3, 0.3]
    level = (level / _MAX_LEVEL) * 0.3
    level = _randomly_negate(level)
    return (level,)


def _translate_abs_level_to_arg(level, hparams):
    translate_const = hparams["translate_const"]
    level = (level / _MAX_LEVEL) * float(translate_const)
    level = _randomly_negate(level)
    return (level,)


def _translate_rel_level_to_arg(level, hparams):
    # default range [-0.45, 0.45]
    translate_pct = hparams.get("translate_pct", 0.45)
    level = (level / _MAX_LEVEL) * translate_pct
    level = _randomly_negate(level)
    return (level,)


def _posterize_level_to_arg(level, _hparams):
    # As per Tensorflow TPU EfficientNet impl
    # range [0, 4], 'keep 0 up to 4 MSB of original image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 4),)


def _posterize_increasing_level_to_arg(level, hparams):
    # As per Tensorflow models research and UDA impl
    # range [4, 0], 'keep 4 down to 0 MSB of original image',
    # intensity/severity of augmentation increases with level
    return (4 - _posterize_level_to_arg(level, hparams)[0],)


def _posterize_original_level_to_arg(level, _hparams):
    # As per original AutoAugment paper description
    # range [4, 8], 'keep 4 up to 8 MSB of image'
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 4) + 4,)


def _solarize_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation decreases with level
    return (int((level / _MAX_LEVEL) * 256),)


def _solarize_increasing_level_to_arg(level, _hparams):
    # range [0, 256]
    # intensity/severity of augmentation increases with level
    return (256 - _solarize_level_to_arg(level, _hparams)[0],)


def _solarize_add_level_to_arg(level, _hparams):
    # range [0, 110]
    return (int((level / _MAX_LEVEL) * 110),)


LEVEL_TO_ARG = {
    "AutoContrast": None,
    "Equalize": None,
    "Invert": None,
    "Rotate": _rotate_level_to_arg,
    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
    "Posterize": _posterize_level_to_arg,
    "PosterizeIncreasing": _posterize_increasing_level_to_arg,
    "PosterizeOriginal": _posterize_original_level_to_arg,
    "Solarize": _solarize_level_to_arg,
    "SolarizeIncreasing": _solarize_increasing_level_to_arg,
    "SolarizeAdd": _solarize_add_level_to_arg,
    "Color": _enhance_level_to_arg,
    "ColorIncreasing": _enhance_increasing_level_to_arg,
    "Contrast": _enhance_level_to_arg,
    "ContrastIncreasing": _enhance_increasing_level_to_arg,
    "Brightness": _enhance_level_to_arg,
    "BrightnessIncreasing": _enhance_increasing_level_to_arg,
    "Sharpness": _enhance_level_to_arg,
    "SharpnessIncreasing": _enhance_increasing_level_to_arg,
    "ShearX": _shear_level_to_arg,
    "ShearY": _shear_level_to_arg,
    "TranslateX": _translate_abs_level_to_arg,
    "TranslateY": _translate_abs_level_to_arg,
    "TranslateXRel": _translate_rel_level_to_arg,
    "TranslateYRel": _translate_rel_level_to_arg,
}


NAME_TO_OP = {
    "AutoContrast": auto_contrast,
    "Equalize": equalize,
    "Invert": invert,
    "Rotate": rotate,
    "Posterize": posterize,
    "PosterizeIncreasing": posterize,
    "PosterizeOriginal": posterize,
    "Solarize": solarize,
    "SolarizeIncreasing": solarize,
    "SolarizeAdd": solarize_add,
    "Color": color,
    "ColorIncreasing": color,
    "Contrast": contrast,
    "ContrastIncreasing": contrast,
    "Brightness": brightness,
    "BrightnessIncreasing": brightness,
    "Sharpness": sharpness,
    "SharpnessIncreasing": sharpness,
    "ShearX": shear_x,
    "ShearY": shear_y,
    "TranslateX": translate_x_abs,
    "TranslateY": translate_y_abs,
    "TranslateXRel": translate_x_rel,
    "TranslateYRel": translate_y_rel,
}


class AugmentOp:
    """
    Apply for video.
    """

    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
        hparams = hparams or _HPARAMS_DEFAULT
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.magnitude = magnitude
        self.hparams = hparams.copy()
        self.kwargs = {
            "fillcolor": hparams["img_mean"]
            if "img_mean" in hparams
            else _FILL,
            "resample": hparams["interpolation"]
            if "interpolation" in hparams
            else _RANDOM_INTERPOLATION,
        }

        # If magnitude_std is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from a normal distribution
        # with mean `magnitude` and std-dev of `magnitude_std`.
        # NOTE This is my own hack, being tested, not in papers or reference impls.
        self.magnitude_std = self.hparams.get("magnitude_std", 0)

    def __call__(self, img_list):
        if self.prob < 1.0 and random.random() > self.prob:
            return img_list
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
        level_args = (
            self.level_fn(magnitude, self.hparams)
            if self.level_fn is not None
            else ()
        )

        if isinstance(img_list, list):
            return [
                self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
            ]
        else:
            return self.aug_fn(img_list, *level_args, **self.kwargs)


_RAND_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "Posterize",
    "Solarize",
    "SolarizeAdd",
    "Color",
    "Contrast",
    "Brightness",
    "Sharpness",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
]


_RAND_INCREASING_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "PosterizeIncreasing",
    "SolarizeIncreasing",
    "SolarizeAdd",
    "ColorIncreasing",
    "ContrastIncreasing",
    "BrightnessIncreasing",
    "SharpnessIncreasing",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
]


# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
    "Rotate": 0.3,
    "ShearX": 0.2,
    "ShearY": 0.2,
    "TranslateXRel": 0.1,
    "TranslateYRel": 0.1,
    "Color": 0.025,
    "Sharpness": 0.025,
    "AutoContrast": 0.025,
    "Solarize": 0.005,
    "SolarizeAdd": 0.005,
    "Contrast": 0.005,
    "Brightness": 0.005,
    "Equalize": 0.005,
    "Posterize": 0,
    "Invert": 0,
}


def _select_rand_weights(weight_idx=0, transforms=None):
    transforms = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    rand_weights = _RAND_CHOICE_WEIGHTS_0
    probs = [rand_weights[k] for k in transforms]
    probs /= np.sum(probs)
    return probs


def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [
        AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
        for name in transforms
    ]


class RandAugment:
    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops,
            self.num_layers,
            replace=self.choice_weights is None,
            p=self.choice_weights,
        )
        for op in ops:
            img = op(img)
        return img


def rand_augment_transform(config_str, hparams):
    """
    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719

    Create a RandAugment transform
    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
    sections, not order sepecific determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
        'mstd' -  float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
    :return: A PyTorch compatible Transform
    """
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split("-")
    assert config[0] == "rand"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param injected via hparams for now
            hparams.setdefault("magnitude_std", float(val))
        elif key == "inc":
            if bool(val):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == "m":
            magnitude = int(val)
        elif key == "n":
            num_layers = int(val)
        elif key == "w":
            weight_idx = int(val)
        else:
            assert NotImplementedError
    ra_ops = rand_augment_ops(
        magnitude=magnitude, hparams=hparams, transforms=transforms
    )
    choice_weights = (
        None if weight_idx is None else _select_rand_weights(weight_idx)
    )
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)


================================================
FILE: video_dataset/random_erasing.py
================================================
#!/usr/bin/env python
# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/random_erasing.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""
This implementation is based on
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
pulished under an Apache License 2.0.

COMMENT FROM ORIGINAL:
Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
Copyright Zhun Zhong & Liang Zheng
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import random
import torch


def _get_pixels(
    per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
):
    # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
    # paths, flip the order so normal is run on CPU if this becomes a problem
    # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
    if per_pixel:
        return torch.empty(patch_size, dtype=dtype, device=device).normal_()
    elif rand_color:
        return torch.empty(
            (patch_size[0], 1, 1), dtype=dtype, device=device
        ).normal_()
    else:
        return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)


class RandomErasing:
    """Randomly selects a rectangle region in an image and erases its pixels.
        'Random Erasing Data Augmentation' by Zhong et al.
        See https://arxiv.org/pdf/1708.04896.pdf
        This variant of RandomErasing is intended to be applied to either a batch
        or single image tensor after it has been normalized by dataset mean and std.
    Args:
         probability: Probability that the Random Erasing operation will be performed.
         min_area: Minimum percentage of erased area wrt input image area.
         max_area: Maximum percentage of erased area wrt input image area.
         min_aspect: Minimum aspect ratio of erased area.
         mode: pixel color mode, one of 'const', 'rand', or 'pixel'
            'const' - erase block is constant color of 0 for all channels
            'rand'  - erase block is same per-channel random (normal) color
            'pixel' - erase block is per-pixel random (normal) color
        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
            per-image count is randomly chosen between 1 and this value.
    """

    def __init__(
        self,
        probability=0.5,
        min_area=0.02,
        max_area=1 / 3,
        min_aspect=0.3,
        max_aspect=None,
        mode="const",
        min_count=1,
        max_count=None,
        num_splits=0,
        device="cuda",
        cube=True,
    ):
        self.probability = probability
        self.min_area = min_area
        self.max_area = max_area
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.num_splits = num_splits
        mode = mode.lower()
        self.rand_color = False
        self.per_pixel = False
        self.cube = cube
        if mode == "rand":
            self.rand_color = True  # per block random normal
        elif mode == "pixel":
            self.per_pixel = True  # per pixel random normal
        else:
            assert not mode or mode == "const"
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = (
            self.min_count
            if self.min_count == self.max_count
            else random.randint(self.min_count, self.max_count)
        )
        for _ in range(count):
            for _ in range(10):
                target_area = (
                    random.uniform(self.min_area, self.max_area) * area / count
                )
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    img[:, top : top + h, left : left + w] = _get_pixels(
                        self.per_pixel,
                        self.rand_color,
                        (chan, h, w),
                        dtype=dtype,
                        device=self.device,
                    )
                    break

    def _erase_cube(
        self,
        img,
        batch_start,
        batch_size,
        chan,
        img_h,
        img_w,
        dtype,
    ):
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = (
            self.min_count
            if self.min_count == self.max_count
            else random.randint(self.min_count, self.max_count)
        )
        for _ in range(count):
            for _ in range(100):
                target_area = (
                    random.uniform(self.min_area, self.max_area) * area / count
                )
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    for i in range(batch_start, batch_size):
                        img_instance = img[i]
                        img_instance[
                            :, top : top + h, left : left + w
                        ] = _get_pixels(
                            self.per_pixel,
                            self.rand_color,
                            (chan, h, w),
                            dtype=dtype,
                            device=self.device,
                        )
                    break

    def __call__(self, input):
        if len(input.size()) == 3:
            self._erase(input, *input.size(), input.dtype)
        else:
            batch_size, chan, img_h, img_w = input.size()
            # skip first slice of batch if num_splits is set (for clean portion of samples)
            batch_start = (
                batch_size // self.num_splits if self.num_splits > 1 else 0
            )
            if self.cube:
                self._erase_cube(
                    input,
                    batch_start,
                    batch_size,
                    chan,
                    img_h,
                    img_w,
                    input.dtype,
                )
            else:
                for i in range(batch_start, batch_size):
                    self._erase(input[i], chan, img_h, img_w, input.dtype)
        return input


================================================
FILE: video_dataset/transform.py
================================================
#!/usr/bin/env python3
# Originate from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/transform.py
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import logging
import math
import numpy as np

# import cv2
import random
import torch
import torchvision as tv
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter
from torchvision import transforms

from .rand_augment import rand_augment_transform
from .random_erasing import RandomErasing

_pil_interpolation_to_str = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}


_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _pil_interp(method):
    if method == "bicubic":
        return Image.BICUBIC
    elif method == "lanczos":
        return Image.LANCZOS
    elif method == "hamming":
        return Image.HAMMING
    else:
        return Image.BILINEAR


logger = logging.getLogger(__name__)


def random_short_side_scale_jitter(
    images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
):
    """
    Perform a spatial short scale jittering on the given images and
    corresponding boxes.
    Args:
        images (tensor): images to perform scale jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        min_size (int): the minimal size to scale the frames.
        max_size (int): the maximal size to scale the frames.
        boxes (ndarray): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
            scale. If False, take a uniform sample from [min_scale, max_scale].
    Returns:
        (tensor): the scaled images with dimension of
            `num frames` x `channel` x `new height` x `new width`.
        (ndarray or None): the scaled boxes with dimension of
            `num boxes` x 4.
    """
    if inverse_uniform_sampling:
        size = int(
            round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
        )
    else:
        size = int(round(np.random.uniform(min_size, max_size)))

    height = images.shape[2]
    width = images.shape[3]
    if (width <= height and width == size) or (
        height <= width and height == size
    ):
        return images, boxes
    new_width = size
    new_height = size
    if width < height:
        new_height = int(math.floor((float(height) / width) * size))
        if boxes is not None:
            boxes = boxes * float(new_height) / height
    else:
        new_width = int(math.floor((float(width) / height) * size))
        if boxes is not None:
            boxes = boxes * float(new_width) / width

    return (
        torch.nn.functional.interpolate(
            images,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        ),
        boxes,
    )


def crop_boxes(boxes, x_offset, y_offset):
    """
    Peform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to peform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes


def random_crop(images, size, boxes=None):
    """
    Perform random spatial crop on the given images and corresponding boxes.
    Args:
        images (tensor): images to perform random crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): the size of height and width to crop on the image.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        cropped (tensor): cropped images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    if images.shape[2] == size and images.shape[3] == size:
        return images, boxes
    height = images.shape[2]
    width = images.shape[3]
    y_offset = 0
    if height > size:
        y_offset = int(np.random.randint(0, height - size))
    x_offset = 0
    if width > size:
        x_offset = int(np.random.randint(0, width - size))
    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]

    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )

    return cropped, cropped_boxes


def horizontal_flip(prob, images, boxes=None):
    """
    Perform horizontal flip on the given images and corresponding boxes.
    Args:
        prob (float): probility to flip the images.
        images (tensor): images to perform horizontal flip, the dimension is
            `num frames` x `channel` x `height` x `width`.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        images (tensor): images with dimension of
            `num frames` x `channel` x `height` x `width`.
        flipped_boxes (ndarray or None): the flipped boxes with dimension of
            `num boxes` x 4.
    """
    if boxes is None:
        flipped_boxes = None
    else:
        flipped_boxes = boxes.copy()

    if np.random.uniform() < prob:
        images = images.flip((-1))

        if len(images.shape) == 3:
            width = images.shape[2]
        elif len(images.shape) == 4:
            width = images.shape[3]
        else:
            raise NotImplementedError("Dimension does not supported")
        if boxes is not None:
            flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1

    return images, flipped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and weight to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optinal. If not None, resize the images to scale_size before
            performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]
    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes


def clip_boxes_to_image(boxes, height, width):
    """
    Clip an array of boxes to an image with the given height and width.
    Args:
        boxes (ndarray): bounding boxes to perform clipping.
            Dimension is `num boxes` x 4.
        height (int): given image height.
        width (int): given image width.
    Returns:
        clipped_boxes (ndarray): the clipped boxes with dimension of
            `num boxes` x 4.
    """
    clipped_boxes = boxes.copy()
    clipped_boxes[:, [0, 2]] = np.minimum(
        width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
    )
    clipped_boxes[:, [1, 3]] = np.minimum(
        height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
    )
    return clipped_boxes


def blend(images1, images2, alpha):
    """
    Blend two images with a given weight alpha.
    Args:
        images1 (tensor): the first images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        images2 (tensor): the second images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        alpha (float): the blending weight.
    Returns:
        (tensor): blended images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    return images1 * alpha + images2 * (1 - alpha)


def grayscale(images):
    """
    Get the grayscale for the input images. The channels of images should be
    in order BGR.
    Args:
        images (tensor): the input images for getting grayscale. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        img_gray (tensor): blended images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    # R -> 0.299, G -> 0.587, B -> 0.114.
    img_gray = torch.tensor(images)
    gray_channel = (
        0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
    )
    img_gray[:, 0] = gray_channel
    img_gray[:, 1] = gray_channel
    img_gray[:, 2] = gray_channel
    return img_gray


def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
    """
    Perfrom a color jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        img_brightness (float): jitter ratio for brightness.
        img_contrast (float): jitter ratio for contrast.
        img_saturation (float): jitter ratio for saturation.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """

    jitter = []
    if img_brightness != 0:
        jitter.append("brightness")
    if img_contrast != 0:
        jitter.append("contrast")
    if img_saturation != 0:
        jitter.append("saturation")

    if len(jitter) > 0:
        order = np.random.permutation(np.arange(len(jitter)))
        for idx in range(0, len(jitter)):
            if jitter[order[idx]] == "brightness":
                images = brightness_jitter(img_brightness, images)
            elif jitter[order[idx]] == "contrast":
                images = contrast_jitter(img_contrast, images)
            elif jitter[order[idx]] == "saturation":
                images = saturation_jitter(img_saturation, images)
    return images


def brightness_jitter(var, images):
    """
    Perfrom brightness jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for brightness.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)

    img_bright = torch.zeros(images.shape)
    images = blend(images, img_bright, alpha)
    return images


def contrast_jitter(var, images):
    """
    Perfrom contrast jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for contrast.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)

    img_gray = grayscale(images)
    img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
    images = blend(images, img_gray, alpha)
    return images


def saturation_jitter(var, images):
    """
    Perfrom saturation jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for saturation.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)
    img_gray = grayscale(images)
    images = blend(images, img_gray, alpha)

    return images


def lighting_jitter(images, alphastd, eigval, eigvec):
    """
    Perform AlexNet-style PCA jitter on the given images.
    Args:
        images (tensor): images to perform lighting jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        alphastd (float): jitter ratio for PCA jitter.
        eigval (list): eigenvalues for PCA jitter.
        eigvec (list[list]): eigenvectors for PCA jitter.
    Returns:
        out_images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    if alphastd == 0:
        return images
    # generate alpha1, alpha2, alpha3.
    alpha = np.random.normal(0, alphastd, size=(1, 3))
    eig_vec = np.array(eigvec)
    eig_val = np.reshape(eigval, (1, 3))
    rgb = np.sum(
        eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
        axis=1,
    )
    out_images = torch.zeros_like(images)
    if len(images.shape) == 3:
        # C H W
        channel_dim = 0
    elif len(images.shape) == 4:
        # T C H W
        channel_dim = 1
    else:
        raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")

    for idx in range(images.shape[channel_dim]):
        # C H W
        if len(images.shape) == 3:
            out_images[idx] = images[idx] + rgb[2 - idx]
        # T C H W
        elif len(images.shape) == 4:
            out_images[:, idx] = images[:, idx] + rgb[2 - idx]
        else:
            raise NotImplementedError(
                f"Unsupported dimension {len(images.shape)}"
            )

    return out_images


def color_normalization(images, mean, stddev):
    """
    Perform color nomration on the given images.
    Args:
        images (tensor): images to perform color normalization. Dimension is
            `num frames` x `channel` x `height` x `width`.
        mean (list): mean values for normalization.
        stddev (list): standard deviations for normalization.

    Returns:
        out_images (tensor): the noramlized images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    if len(images.shape) == 3:
        assert (
            len(mean) == images.shape[0]
        ), "channel mean not computed properly"
        assert (
            len(stddev) == images.shape[0]
        ), "channel stddev not computed properly"
    elif len(images.shape) == 4:
        assert (
            len(mean) == images.shape[1]
        ), "channel mean not computed properly"
        assert (
            len(stddev) == images.shape[1]
        ), "channel stddev not computed properly"
    else:
        raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")

    out_images = torch.zeros_like(images)
    for idx in range(len(mean)):
        # C H W
        if len(images.shape) == 3:
            out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
        elif len(images.shape) == 4:
            out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
        else:
            raise NotImplementedError(
                f"Unsupported dimension {len(images.shape)}"
            )
    return out_images


def _get_param_spatial_crop(
    scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
):
    """
    Given scale, ratio, height and width, return sampled coordinates of the videos.
    """
    for _ in range(num_repeat):
        area = height * width
        target_area = random.uniform(*scale) * area
        if log_scale:
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
        else:
            aspect_ratio = random.uniform(*ratio)

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        if np.random.uniform() < 0.5 and switch_hw:
            w, h = h, w

        if 0 < w <= width and 0 < h <= height:
            i = random.randint(0, height - h)
            j = random.randint(0, width - w)
            return i, j, h, w

    # Fallback to central crop
    in_ratio = float(width) / float(height)
    if in_ratio < min(ratio):
        w = width
        h = int(round(w / min(ratio)))
    elif in_ratio > max(ratio):
        h = height
        w = int(round(h * max(ratio)))
    else:  # whole image
        w = width
        h = height
    i = (height - h) // 2
    j = (width - w) // 2
    return i, j, h, w


def random_resized_crop(
    images,
    target_height,
    target_width,
    scale=(0.08, 1.0),
    ratio=(3.0 / 4.0, 4.0 / 3.0),
):
    """
    Crop the given images to random size and aspect ratio. A crop of random
    size (default: of 0.08 to 1.0) of the original size and a random aspect
    ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This
    crop is finally resized to given size. This is popularly used to train the
    Inception networks.

    Args:
        images: Images to perform resizing and cropping.
        target_height: Desired height after cropping.
        target_width: Desired width after cropping.
        scale: Scale range of Inception-style area based random resizing.
        ratio: Aspect ratio range of Inception-style area based random resizing.
    """

    height = images.shape[2]
    width = images.shape[3]

    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
    cropped = images[:, :, i : i + h, j : j + w]
    return torch.nn.functional.interpolate(
        cropped,
        size=(target_height, target_width),
        mode="bilinear",
        align_corners=False,
    )


def random_resized_crop_with_shift(
    images,
    target_height,
    target_width,
    scale=(0.8, 1.0),
    ratio=(3.0 / 4.0, 4.0 / 3.0),
):
    """
    This is similar to random_resized_crop. However, it samples two different
    boxes (for cropping) for the first and last frame. It then linearly
    interpolates the two boxes for other frames.

    Args:
        images: Images to perform resizing and cropping.
        target_height: Desired height after cropping.
        target_width: Desired width after cropping.
        scale: Scale range of Inception-style area based random resizing.
        ratio: Aspect ratio range of Inception-style area based random resizing.
    """
    t = images.shape[1]
    height = images.shape[2]
    width = images.shape[3]

    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
    i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
    i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
    j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
    h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
    w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
    out = torch.zeros((3, t, target_height, target_width))
    for ind in range(t):
        out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
            images[
                :,
                ind : ind + 1,
                i_s[ind] : i_s[ind] + h_s[ind],
                j_s[ind] : j_s[ind] + w_s[ind],
            ],
            size=(target_height, target_width),
            mode="bilinear",
            align_corners=False,
        )
    return out


def create_random_augment(
    input_size,
    auto_augment=None,
    interpolation="bilinear",
):
    """
    Get video randaug transform.

    Args:
        input_size: The size of the input video in tuple.
        auto_augment: Parameters for randaug. An example:
            "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
            of operations to apply).
        interpolation: Interpolation method.
    """
    if isinstance(input_size, tuple):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = {"translate_const": int(img_size_min * 0.45)}
        if interpolation and interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        if auto_augment.startswith("rand"):
            return transforms.Compose(
                [rand_augment_transform(auto_augment, aa_params)]
            )
    raise NotImplementedError


def random_sized_crop_img(
    im,
    size,
    jitter_scale=(0.08, 1.0),
    jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
    max_iter=10,
):
    """
    Performs Inception-style cropping (used for training).
    """
    assert (
        len(im.shape) == 3
    ), "Currently only support image for random_sized_crop"
    h, w = im.shape[1:3]
    i, j, h, w = _get_param_spatial_crop(
        scale=jitter_scale,
        ratio=jitter_aspect,
        height=h,
        width=w,
        num_repeat=max_iter,
        log_scale=False,
        switch_hw=True,
    )
    cropped = im[:, i : i + h, j : j + w]
    return torch.nn.functional.interpolate(
        cropped.unsqueeze(0),
        size=(size, size),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)


# The following code are modified based on timm lib, we will replace the following
# contents with dependency from PyTorchVideo.
# https://github.com/facebookresearch/pytorchvideo
class RandomResizedCropAndInterpolation:
    """Crop the given PIL Image to random size and aspect ratio with random interpolation.
    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.
    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation="bilinear",
    ):
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            print("range should be of kind (min, max)")

        if interpolation == "random":
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = _pil_interp(interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.
        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        area = img.size[0] * img.size[1]

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = img.size[0] / img.size[1]
        if in_ratio < min(ratio):
            w = img.size[0]
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = img.size[1]
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img.size[0]
            h = img.size[1]
        i = (img.size[1] - h) // 2
        j = (img.size[0] - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.
        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        return F.resized_crop(img, i, j, h, w, self.size, interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = " ".join(
                [_pil_interpolation_to_str[x] for x in self.interpolation]
            )
        else:
            interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + "(size={0}".format(self.size)
        format_string += ", scale={0}".format(
            tuple(round(s, 4) for s in self.scale)
        )
        format_string += ", ratio={0}".format(
            tuple(round(r, 4) for r in self.ratio)
        )
        format_string += ", interpolation={0})".format(interpolate_str)
        return format_string


================================================
FILE: vision_transformer.py
================================================
#!/usr/bin/env python

from collections import OrderedDict
import numpy as np
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

'''
QuickGELU and LayerNorm w/ fp16 from official CLIP repo
(https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py)
'''
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)

class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class Attention(nn.Module):
    '''
    A generalized attention module with more flexibility.
    '''

    def __init__(
        self, q_in_dim: int, k_in_dim: int, v_in_dim: int,
        qk_proj_dim: int, v_proj_dim: int, num_heads: int, out_dim: int,
        return_all_features: bool = False,
    ):
        super().__init__()

        self.q_proj = nn.Linear(q_in_dim, qk_proj_dim)
        self.k_proj = nn.Linear(k_in_dim, qk_proj_dim)
        self.v_proj = nn.Linear(v_in_dim, v_proj_dim)
        self.out_proj = nn.Linear(v_proj_dim, out_dim)

        self.num_heads = num_heads
        self.return_all_features = return_all_features
        assert qk_proj_dim % num_heads == 0 and v_proj_dim % num_heads == 0

        self._initialize_weights()


    def _initialize_weights(self):
        for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.)


    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
        N = q.size(0); assert k.size(0) == N and v.size(0) == N
        Lq, Lkv = q.size(1), k.size(1); assert v.size(1) == Lkv

        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        
        H = self.num_heads
        Cqk, Cv = q.size(-1) // H, v.size(-1) // H

        q = q.view(N, Lq, H, Cqk)
        k = k.view(N, Lkv, H, Cqk)
        v = v.view(N, Lkv, H, Cv)

        aff = torch.einsum('nqhc,nkhc->nqkh', q / (Cqk ** 0.5), k)
        aff = aff.softmax(dim=-2)
        mix = torch.einsum('nqlh,nlhc->nqhc', aff, v)

        out = self.out_proj(mix.flatten(-2))

        if self.return_all_features:
            return dict(q=q, k=k, v=v, aff=aff, out=out)
        else:
            return out


class PatchEmbed2D(nn.Module):

    def __init__(
        self,
        patch_size: Tuple[int, int] = (16, 16),
        in_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels

        self.proj = nn.Linear(np.prod(patch_size) * in_channels, embed_dim)


    def _initialize_weights(self, x):
        nn.init.kaiming_normal_(self.proj.weight, 0.)
        nn.init.constant_(self.proj.bias, 0.)


    def forward(self, x: torch.Tensor):
        B, C, H, W = x.size()
        pH, pW = self.patch_size

        assert C == self.in_channels and H % pH == 0 and W % pW == 0

        x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2)
        x = self.proj(x)
        
        return x

class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        mlp_dropout: float = 0.0,
        act: nn.Module = QuickGELU,
        return_all_features: bool = False,
    ):
        super().__init__()

        self.return_all_features = return_all_features

        self.attn = Attention(
            q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
            qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,
            return_all_features=return_all_features,
        )

        mlp_dim = round(mlp_factor * in_feature_dim)
        self.mlp = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(in_feature_dim, mlp_dim)),
            ('act', act()),
            ('dropout', nn.Dropout(mlp_dropout)),
            ('fc2', nn.Linear(mlp_dim, in_feature_dim)),
        ]))

        self.norm1 = LayerNorm(in_feature_dim)
        self.norm2 = LayerNorm(in_feature_dim)

        self._initialize_weights()


    def _initialize_weights(self):
        for m in (self.mlp[0], self.mlp[-1]):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)


    def forward(self, x: torch.Tensor):
        if self.return_all_features:
            ret_dict = {}
            
            x_norm = self.norm1(x)
            attn_out = self.attn(x_norm, x_norm, x_norm)
            ret_dict['q'] = attn_out['q']
            ret_dict['k'] = attn_out['k']
            ret_dict['v'] = attn_out['v']
            ret_dict['attn_out'] = attn_out['out']
            x = x + attn_out['out']

            x = x + self.mlp(self.norm2(x))
            ret_dict['out'] = x

            return ret_dict
        
        else:
            x_norm = self.norm1(x)
            x = x + self.attn(x_norm, x_norm, x_norm)
            x = x + self.mlp(self.norm2(x))

            return x


class TransformerDecoderLayer(nn.Module):

    def __init__(
        self,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        mlp_dropout: float = 0.0,
        act: nn.Module = QuickGELU,
    ):
        super().__init__()

        self.attn = Attention(
            q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
            qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,
        )

        mlp_dim = round(mlp_factor * in_feature_dim)
        self.mlp = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(in_feature_dim, mlp_dim)),
            ('act', act()),
            ('dropout', nn.Dropout(mlp_dropout)),
            ('fc2', nn.Linear(mlp_dim, in_feature_dim)),
        ]))

        self.norm1 = LayerNorm(in_feature_dim)
        self.norm2 = LayerNorm(in_feature_dim)
        self.norm3 = LayerNorm(in_feature_dim)

        self._initialize_weights()


    def _initialize_weights(self):
        for m in (self.mlp[0], self.mlp[-1]):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)


    def forward(self, x: torch.Tensor, y: torch.Tensor):
        y_norm = self.norm3(y)
        x = x + self.attn(self.norm1(x), y_norm, y_norm)
        x = x + self.mlp(self.norm2(x))

        return x


class VisionTransformer2D(nn.Module):

    def __init__(
        self,
        feature_dim: int = 768,
        input_size: Tuple[int, int] = (224, 224),
        patch_size: Tuple[int, int] = (16, 16),
        num_heads: int = 12,
        num_layers: int = 12,
        mlp_factor: float = 4.0,
        act: nn.Module = QuickGELU,
        return_all_features: bool = False,
        ln_pre: bool = False,
    ):
        super().__init__()

        self.return_all_features = return_all_features
        
        self.patch_embed = PatchEmbed2D(patch_size=patch_size, embed_dim=feature_dim)
        self.num_patches = np.prod([x // y for x, y in zip(input_size, patch_size)]) + 1

        self.cls_token = nn.Parameter(torch.zeros([feature_dim]))
        self.pos_embed = nn.Parameter(torch.zeros([self.num_patches, feature_dim]))

        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(
                in_feature_dim=feature_dim, qkv_dim=feature_dim, num_heads=num_heads, mlp_factor=mlp_factor, act=act,
                return_all_features=return_all_features,
            ) for _ in range(num_layers)
        ])

        if ln_pre:
            self.ln_pre = LayerNorm(feature_dim)
        else:
            self.ln_pre = nn.Identity()

        self._initialize_weights()


    def _initialize_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x: torch.Tensor):
        dtype = self.patch_embed.proj.weight.dtype
        x = x.to(dtype)

        x = self.patch_embed(x)
        x = torch.cat([self.cls_token.view(1, 1, -1).repeat(x.size(0), 1, 1), x], dim=1)
        x = x + self.pos_embed

        x = self.ln_pre(x)

        if self.return_all_features:
            all_features = []
            for blk in self.blocks:
                x = blk(x)
                all_features.append(x)
                x = x['out']
            return all_features
        
        else:
            for blk in self.blocks:
                x = blk(x)
            return x


def model_to_fp16(model: VisionTransformer2D):
    def _module_to_fp16(m: nn.Module):
        if isinstance(m, (nn.Linear,)):
            m.half()
    model.apply(_module_to_fp16)

    model.pos_embed.data = model.pos_embed.data.half()
    model.cls_token.data = model.cls_token.data.half()


vit_presets = {
    'ViT-B/16-lnpre': dict(
        feature_dim=768,
        input_size=(224, 224),
        patch_size=(16, 16),
        num_heads=12,
        num_layers=12,
        mlp_factor=4.0,
        ln_pre=True,
    ),
    'ViT-L/14-lnpre': dict(
        feature_dim=1024,
        input_size=(224, 224),
        patch_size=(14, 14),
        num_heads=16,
        num_layers=24,
        mlp_factor=4.0,
        ln_pre=True,
    ),
}

================================================
FILE: weight_loaders.py
================================================
#!/usr/bin/env python

import os, sys
from typing import Dict

import torch

__all__ = ['weight_loader_fn_dict']

def load_weights_clip(load_path: str) -> Dict[str, torch.Tensor]:
    clip_model = torch.jit.load(load_path, map_location='cpu')
    clip_model = clip_model.visual
    src_state_dict = clip_model.state_dict()
    src_state_dict = dict((k, v.float()) for k, v in src_state_dict.items())

    dst_state_dict = {}
    
    dst_state_dict['cls_token'] = src_state_dict['class_embedding']
    dst_state_dict['pos_embed'] = src_state_dict['positional_embedding']
    dst_state_dict['patch_embed.proj.weight'] = src_state_dict['conv1.weight'].flatten(1)
    dst_state_dict['patch_embed.proj.bias'] = torch.zeros([src_state_dict['conv1.weight'].size(0)])
    
    dst_state_dict['ln_pre.weight'] = src_state_dict['ln_pre.weight']
    dst_state_dict['ln_pre.bias'] = src_state_dict['ln_pre.bias']

    block_idx = 0
    while True:
        src_prefix = 'transformer.resblocks.%d.' % block_idx
        dst_prefix = 'blocks.%d.' % block_idx

        src_block_state_dict = dict((k[len(src_prefix):], v) for k, v in src_state_dict.items() if k.startswith(src_prefix))
        if len(src_block_state_dict) == 0:
            break

        dst_block_state_dict = {}
        feat_dim = src_block_state_dict['ln_1.weight'].size(0)

        for i, dst_name in enumerate(('q', 'k', 'v')):
            dst_block_state_dict['attn.%s_proj.weight' % dst_name] = src_block_state_dict['attn.in_proj_weight'][feat_dim * i: feat_dim * (i + 1)]
            dst_block_state_dict['attn.%s_proj.bias' % dst_name] = src_block_state_dict['attn.in_proj_bias'][feat_dim * i: feat_dim * (i + 1)]
        
        dst_block_state_dict['attn.out_proj.weight'] = src_block_state_dict['attn.out_proj.weight']
        dst_block_state_dict['attn.out_proj.bias'] = src_block_state_dict['attn.out_proj.bias']

        dst_block_state_dict['mlp.fc1.weight'] = src_block_state_dict['mlp.c_fc.weight']
        dst_block_state_dict['mlp.fc1.bias'] = src_block_state_dict['mlp.c_fc.bias']
        dst_block_state_dict['mlp.fc2.weight'] = src_block_state_dict['mlp.c_proj.weight']
        dst_block_state_dict['mlp.fc2.bias'] = src_block_state_dict['mlp.c_proj.bias']

        dst_block_state_dict['norm1.weight'] = src_block_state_dict['ln_1.weight']
        dst_block_state_dict['norm1.bias'] = src_block_state_dict['ln_1.bias']
        dst_block_state_dict['norm2.weight'] = src_block_state_dict['ln_2.weight']
        dst_block_state_dict['norm2.bias'] = src_block_state_dict['ln_2.bias']

        dst_state_dict.update(dict((dst_prefix + k, v) for k, v in dst_block_state_dict.items()))
        block_idx += 1

    return dst_state_dict


weight_loader_fn_dict = {
    'clip': load_weights_clip,
}
Download .txt
gitextract_7dyo2vh1/

├── .gitignore
├── README.md
├── checkpoint.py
├── data/
│   └── k400_class_mappings.json
├── main.py
├── model.py
├── scripts/
│   ├── eval_k400_vitb16_16f_dec4x768.sh
│   ├── eval_k400_vitb16_32f_dec4x768.sh
│   ├── eval_k400_vitb16_8f_dec4x768.sh
│   ├── eval_k400_vitl14_16f_dec4x1024.sh
│   ├── eval_k400_vitl14_32f_dec4x1024.sh
│   ├── eval_k400_vitl14_8f_dec4x1024.sh
│   ├── train_k400_vitb16_16f_dec4x768.sh
│   ├── train_k400_vitb16_32f_dec4x768.sh
│   ├── train_k400_vitb16_8f_dec4x768.sh
│   ├── train_k400_vitl14_16f_dec4x1024.sh
│   ├── train_k400_vitl14_32f_dec4x1024.sh
│   └── train_k400_vitl14_8f_dec4x1024.sh
├── video_dataset/
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── rand_augment.py
│   ├── random_erasing.py
│   └── transform.py
├── vision_transformer.py
└── weight_loaders.py
Download .txt
SYMBOL INDEX (135 symbols across 10 files)

FILE: checkpoint.py
  function setup_arg_parser (line 10) | def setup_arg_parser(parser: argparse.ArgumentParser):
  function _find_autoresume_path (line 22) | def _find_autoresume_path(args: argparse.Namespace):
  function resume_from_checkpoint (line 46) | def resume_from_checkpoint(
  function save_checkpoint (line 79) | def save_checkpoint(

FILE: main.py
  function setup_print (line 17) | def setup_print(is_master: bool):
  function main (line 33) | def main():
  function evaluate (line 199) | def evaluate(model: torch.nn.Module, loader: torch.utils.data.DataLoader):

FILE: model.py
  class TemporalCrossAttention (line 18) | class TemporalCrossAttention(nn.Module):
    method __init__ (line 20) | def __init__(
    method forward_half (line 44) | def forward_half(self, q: torch.Tensor, k: torch.Tensor, w: torch.Tens...
    method forward (line 60) | def forward(self, q: torch.Tensor, k: torch.Tensor):
  class EVLDecoder (line 71) | class EVLDecoder(nn.Module):
    method __init__ (line 73) | def __init__(
    method _initialize_weights (line 114) | def _initialize_weights(self):
    method forward (line 118) | def forward(self, in_features: List[Dict[str, torch.Tensor]]):
  class EVLTransformer (line 146) | class EVLTransformer(nn.Module):
    method __init__ (line 148) | def __init__(
    method _create_backbone (line 194) | def _create_backbone(
    method _get_backbone (line 220) | def _get_backbone(self, x):
    method forward (line 230) | def forward(self, x: torch.Tensor):

FILE: video_dataset/dataloader.py
  function setup_arg_parser (line 11) | def setup_arg_parser(parser: argparse.ArgumentParser):
  function _parse_mean_and_std (line 59) | def _parse_mean_and_std(args: argparse.Namespace) -> Dict[str, torch.Ten...
  function create_train_dataset (line 75) | def create_train_dataset(args: argparse.Namespace) -> torch.utils.data.D...
  function create_train_loader (line 98) | def create_train_loader(args: argparse.Namespace, resume_step: int = 0) ...
  function create_val_dataset (line 123) | def create_val_dataset(args: argparse.Namespace) -> torch.utils.data.Dat...
  function create_val_loader (line 145) | def create_val_loader(args: argparse.Namespace) -> torch.utils.data.Data...

FILE: video_dataset/dataset.py
  class VideoDataset (line 14) | class VideoDataset(torch.utils.data.Dataset):
    method __init__ (line 16) | def __init__(
    method __len__ (line 46) | def __len__(self):
    method __getitem__ (line 50) | def __getitem__(self, idx):
    method _generate_temporal_crops (line 113) | def _generate_temporal_crops(self, frames):
    method _generate_spatial_crops (line 131) | def _generate_spatial_crops(self, frames):
    method _random_sample_frame_idx (line 155) | def _random_sample_frame_idx(self, len):
  class DummyDataset (line 173) | class DummyDataset(torch.utils.data.Dataset):
    method __init__ (line 175) | def __init__(self, list_path: str, num_frames: int, num_views: int, sp...
    method __len__ (line 182) | def __len__(self):
    method __getitem__ (line 185) | def __getitem__(self, _):

FILE: video_dataset/rand_augment.py
  function _interpolation (line 55) | def _interpolation(kwargs):
  function _check_args_tf (line 63) | def _check_args_tf(kwargs):
  function shear_x (line 69) | def shear_x(img, factor, **kwargs):
  function shear_y (line 76) | def shear_y(img, factor, **kwargs):
  function translate_x_rel (line 83) | def translate_x_rel(img, pct, **kwargs):
  function translate_y_rel (line 91) | def translate_y_rel(img, pct, **kwargs):
  function translate_x_abs (line 99) | def translate_x_abs(img, pixels, **kwargs):
  function translate_y_abs (line 106) | def translate_y_abs(img, pixels, **kwargs):
  function rotate (line 113) | def rotate(img, degrees, **kwargs):
  function auto_contrast (line 147) | def auto_contrast(img, **__):
  function invert (line 151) | def invert(img, **__):
  function equalize (line 155) | def equalize(img, **__):
  function solarize (line 159) | def solarize(img, thresh, **__):
  function solarize_add (line 163) | def solarize_add(img, add, thresh=128, **__):
  function posterize (line 178) | def posterize(img, bits_to_keep, **__):
  function contrast (line 184) | def contrast(img, factor, **__):
  function color (line 188) | def color(img, factor, **__):
  function brightness (line 192) | def brightness(img, factor, **__):
  function sharpness (line 196) | def sharpness(img, factor, **__):
  function _randomly_negate (line 200) | def _randomly_negate(v):
  function _rotate_level_to_arg (line 205) | def _rotate_level_to_arg(level, _hparams):
  function _enhance_level_to_arg (line 212) | def _enhance_level_to_arg(level, _hparams):
  function _enhance_increasing_level_to_arg (line 217) | def _enhance_increasing_level_to_arg(level, _hparams):
  function _shear_level_to_arg (line 225) | def _shear_level_to_arg(level, _hparams):
  function _translate_abs_level_to_arg (line 232) | def _translate_abs_level_to_arg(level, hparams):
  function _translate_rel_level_to_arg (line 239) | def _translate_rel_level_to_arg(level, hparams):
  function _posterize_level_to_arg (line 247) | def _posterize_level_to_arg(level, _hparams):
  function _posterize_increasing_level_to_arg (line 254) | def _posterize_increasing_level_to_arg(level, hparams):
  function _posterize_original_level_to_arg (line 261) | def _posterize_original_level_to_arg(level, _hparams):
  function _solarize_level_to_arg (line 268) | def _solarize_level_to_arg(level, _hparams):
  function _solarize_increasing_level_to_arg (line 274) | def _solarize_increasing_level_to_arg(level, _hparams):
  function _solarize_add_level_to_arg (line 280) | def _solarize_add_level_to_arg(level, _hparams):
  class AugmentOp (line 342) | class AugmentOp:
    method __init__ (line 347) | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
    method __call__ (line 369) | def __call__(self, img_list):
  function _select_rand_weights (line 449) | def _select_rand_weights(weight_idx=0, transforms=None):
  function rand_augment_ops (line 458) | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
  class RandAugment (line 467) | class RandAugment:
    method __init__ (line 468) | def __init__(self, ops, num_layers=2, choice_weights=None):
    method __call__ (line 473) | def __call__(self, img):
  function rand_augment_transform (line 486) | def rand_augment_transform(config_str, hparams):

FILE: video_dataset/random_erasing.py
  function _get_pixels (line 21) | def _get_pixels(
  class RandomErasing (line 37) | class RandomErasing:
    method __init__ (line 56) | def __init__(
    method _erase (line 90) | def _erase(self, img, chan, img_h, img_w, dtype):
    method _erase_cube (line 119) | def _erase_cube(
    method __call__ (line 161) | def __call__(self, input):

FILE: video_dataset/transform.py
  function _pil_interp (line 33) | def _pil_interp(method):
  function random_short_side_scale_jitter (line 47) | def random_short_side_scale_jitter(
  function crop_boxes (line 104) | def crop_boxes(boxes, x_offset, y_offset):
  function random_crop (line 123) | def random_crop(images, size, boxes=None):
  function horizontal_flip (line 159) | def horizontal_flip(prob, images, boxes=None):
  function uniform_crop (line 194) | def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
  function clip_boxes_to_image (line 257) | def clip_boxes_to_image(boxes, height, width):
  function blend (line 279) | def blend(images1, images2, alpha):
  function grayscale (line 295) | def grayscale(images):
  function color_jitter (line 317) | def color_jitter(images, img_brightness=0, img_contrast=0, img_saturatio...
  function brightness_jitter (line 352) | def brightness_jitter(var, images):
  function contrast_jitter (line 371) | def contrast_jitter(var, images):
  function saturation_jitter (line 391) | def saturation_jitter(var, images):
  function lighting_jitter (line 410) | def lighting_jitter(images, alphastd, eigval, eigvec):
  function color_normalization (line 458) | def color_normalization(images, mean, stddev):
  function _get_param_spatial_crop (line 502) | def _get_param_spatial_crop(
  function random_resized_crop (line 544) | def random_resized_crop(
  function random_resized_crop_with_shift (line 579) | def random_resized_crop_with_shift(
  function create_random_augment (line 624) | def create_random_augment(
  function random_sized_crop_img (line 660) | def random_sized_crop_img(
  class RandomResizedCropAndInterpolation (line 695) | class RandomResizedCropAndInterpolation:
    method __init__ (line 708) | def __init__(
    method get_params (line 730) | def get_params(img, scale, ratio):
    method __call__ (line 770) | def __call__(self, img):
    method __repr__ (line 784) | def __repr__(self):

FILE: vision_transformer.py
  class QuickGELU (line 15) | class QuickGELU(nn.Module):
    method forward (line 16) | def forward(self, x: torch.Tensor):
  class LayerNorm (line 19) | class LayerNorm(nn.LayerNorm):
    method forward (line 22) | def forward(self, x: torch.Tensor):
  class Attention (line 28) | class Attention(nn.Module):
    method __init__ (line 33) | def __init__(
    method _initialize_weights (line 52) | def _initialize_weights(self):
    method forward (line 58) | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
  class PatchEmbed2D (line 84) | class PatchEmbed2D(nn.Module):
    method __init__ (line 86) | def __init__(
    method _initialize_weights (line 100) | def _initialize_weights(self, x):
    method forward (line 105) | def forward(self, x: torch.Tensor):
  class TransformerEncoderLayer (line 116) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 118) | def __init__(
    method _initialize_weights (line 152) | def _initialize_weights(self):
    method forward (line 158) | def forward(self, x: torch.Tensor):
  class TransformerDecoderLayer (line 183) | class TransformerDecoderLayer(nn.Module):
    method __init__ (line 185) | def __init__(
    method _initialize_weights (line 216) | def _initialize_weights(self):
    method forward (line 222) | def forward(self, x: torch.Tensor, y: torch.Tensor):
  class VisionTransformer2D (line 230) | class VisionTransformer2D(nn.Module):
    method __init__ (line 232) | def __init__(
    method _initialize_weights (line 269) | def _initialize_weights(self):
    method forward (line 273) | def forward(self, x: torch.Tensor):
  function model_to_fp16 (line 297) | def model_to_fp16(model: VisionTransformer2D):

FILE: weight_loaders.py
  function load_weights_clip (line 10) | def load_weights_clip(load_path: str) -> Dict[str, torch.Tensor]:
Condensed preview — 26 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (130K chars).
[
  {
    "path": ".gitignore",
    "chars": 69,
    "preview": "__pycache__\n*.py[cod]\n*$py.class\n\nruns/\n\nvideo_dataset/io_internal.py"
  },
  {
    "path": "README.md",
    "chars": 5553,
    "preview": "# Frozen CLIP models are Efficient Video Learners\n\nThis is the official implementation of the paper [Frozen CLIP models "
  },
  {
    "path": "checkpoint.py",
    "chars": 3710,
    "preview": "#!/usr/bin/env python\n\nimport argparse\nimport os\n\nimport torch\nimport torch.distributed as dist\n\n\ndef setup_arg_parser(p"
  },
  {
    "path": "data/k400_class_mappings.json",
    "chars": 7959,
    "preview": "[\n  \"abseiling\",\n  \"air drumming\",\n  \"answering questions\",\n  \"applauding\",\n  \"applying cream\",\n  \"archery\",\n  \"arm wres"
  },
  {
    "path": "main.py",
    "chars": 10315,
    "preview": "#!/usr/bin/env python\n\nimport argparse\nfrom datetime import datetime\nimport builtins\n\nimport torch\nimport torch.distribu"
  },
  {
    "path": "model.py",
    "chars": 8745,
    "preview": "#!/usr/bin/env python\n\nfrom typing import Dict, Iterable, List, Tuple\nimport numpy as np\n\nimport torch\nimport torch.nn a"
  },
  {
    "path": "scripts/eval_k400_vitb16_16f_dec4x768.sh",
    "chars": 781,
    "preview": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --bac"
  },
  {
    "path": "scripts/eval_k400_vitb16_32f_dec4x768.sh",
    "chars": 780,
    "preview": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --bac"
  },
  {
    "path": "scripts/eval_k400_vitb16_8f_dec4x768.sh",
    "chars": 779,
    "preview": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --bac"
  },
  {
    "path": "scripts/eval_k400_vitl14_16f_dec4x1024.sh",
    "chars": 783,
    "preview": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --bac"
  },
  {
    "path": "scripts/eval_k400_vitl14_32f_dec4x1024.sh",
    "chars": 782,
    "preview": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --bac"
  },
  {
    "path": "scripts/eval_k400_vitl14_8f_dec4x1024.sh",
    "chars": 781,
    "preview": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --bac"
  },
  {
    "path": "scripts/train_k400_vitb16_16f_dec4x768.sh",
    "chars": 914,
    "preview": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitb16_16f_dec4x768\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --npr"
  },
  {
    "path": "scripts/train_k400_vitb16_32f_dec4x768.sh",
    "chars": 913,
    "preview": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitb16_32f_dec4x768\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --npr"
  },
  {
    "path": "scripts/train_k400_vitb16_8f_dec4x768.sh",
    "chars": 912,
    "preview": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitb16_8f_dec4x768\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --npro"
  },
  {
    "path": "scripts/train_k400_vitl14_16f_dec4x1024.sh",
    "chars": 916,
    "preview": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitl14_16f_dec4x1024\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --np"
  },
  {
    "path": "scripts/train_k400_vitl14_32f_dec4x1024.sh",
    "chars": 915,
    "preview": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitl14_32f_dec4x1024\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --np"
  },
  {
    "path": "scripts/train_k400_vitl14_8f_dec4x1024.sh",
    "chars": 914,
    "preview": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitl14_8f_dec4x1024\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --npr"
  },
  {
    "path": "video_dataset/__init__.py",
    "chars": 103,
    "preview": "#!/usr/bin/env python\n\nfrom .dataloader import setup_arg_parser, create_train_loader, create_val_loader"
  },
  {
    "path": "video_dataset/dataloader.py",
    "chars": 6717,
    "preview": "#!/usr/bin/env python\n\nimport argparse\nfrom typing import Dict\n\nimport torch\nimport torch.distributed as dist\n\nfrom .dat"
  },
  {
    "path": "video_dataset/dataset.py",
    "chars": 7185,
    "preview": "#!/usr/bin/env python\n\nimport os, sys\nfrom typing import Optional\nimport av\nimport io\nimport numpy as np\n\nimport torch\nf"
  },
  {
    "path": "video_dataset/rand_augment.py",
    "chars": 16366,
    "preview": "#!/usr/bin/env python\n# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592"
  },
  {
    "path": "video_dataset/random_erasing.py",
    "chars": 7056,
    "preview": "#!/usr/bin/env python\n# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592"
  },
  {
    "path": "video_dataset/transform.py",
    "chars": 27344,
    "preview": "#!/usr/bin/env python3\n# Originate from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592"
  },
  {
    "path": "vision_transformer.py",
    "chars": 9515,
    "preview": "#!/usr/bin/env python\n\nfrom collections import OrderedDict\nimport numpy as np\nfrom typing import Tuple\n\nimport torch\nimp"
  },
  {
    "path": "weight_loaders.py",
    "chars": 2768,
    "preview": "#!/usr/bin/env python\n\nimport os, sys\nfrom typing import Dict\n\nimport torch\n\n__all__ = ['weight_loader_fn_dict']\n\ndef lo"
  }
]

About this extraction

This page contains the full source code of the OpenGVLab/efficient-video-recognition GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 26 files (120.7 KB), approximately 33.3k tokens, and a symbol index with 135 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!