[
  {
    "path": ".gitignore",
    "content": "__pycache__\n*.py[cod]\n*$py.class\n\nruns/\n\nvideo_dataset/io_internal.py"
  },
  {
    "path": "README.md",
    "content": "# Frozen CLIP models are Efficient Video Learners\n\nThis is the official implementation of the paper [Frozen CLIP models are Efficient Video Learners](https://arxiv.org/abs/2208.03550)\n\n```\n@article{lin2022frozen,\n  title={Frozen CLIP Models are Efficient Video Learners},\n  author={Lin, Ziyi and Geng, Shijie and Zhang, Renrui and Gao, Peng and de Melo, Gerard and Wang, Xiaogang and Dai, Jifeng and Qiao, Yu and Li, Hongsheng},\n  journal={arXiv preprint arXiv:2208.03550},\n  year={2022}\n}\n```\n\n## Introduction\n\nThe overall architecture of the EVL framework includes a trainable Transformer decoder, trainable local temporal modules and a pretrained, fixed image backbone\n(CLIP, for instance).\n\n<img src=\"figs/arch.png\" height=\"300\">\n\nUsing a fixed backbone significantly reduces training time, and we managed to train a ViT-B/16 with 8 frames for 50 epochs in 60 GPU-hours (NVIDIA V100).\n\nDespite the small training computation and memory consumption, EVL models achieve high performance on Kinetics-400. 
A comparison with state-of-the-art methods\nis as follows\n\n<img src=\"figs/k400.png\" height=\"300\">\n\n## Installation\n\nWe tested the released code with the following conda environment\n\n```\nconda create -n pt1.9.0cu11.1_official -c pytorch -c conda-forge pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0 cudatoolkit torchvision av\n```\n\n## Data Preparation\n\nWe expect the `--train_list_path` and `--val_list_path` command line arguments to each point to a data list file of the following format\n```\n<path_1> <label_1>\n<path_2> <label_2>\n...\n<path_n> <label_n>\n```\nwhere `<path_i>` points to a video file, and `<label_i>` is an integer between `0` and `num_classes - 1`.\n`--num_classes` should also be specified on the command line.\n\nAdditionally, `<path_i>` might be a relative path when `--data_root` is specified, and the actual path will be\nrelative to the path passed as `--data_root`.\n\nThe class mappings used by the open-source weights are provided at [Kinetics-400 class mappings](data/k400_class_mappings.json)\n\n## Backbone Preparation\n\nCLIP weights need to be downloaded from [CLIP official repo](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/clip.py#L30)\nand passed to the `--backbone_path` command line argument.\n\n## Script Usage\n\nTraining and evaluation scripts are provided in the scripts folder.\nScripts should be ready to run once the environment is set up and \n`--backbone_path`, `--train_list_path` and `--val_list_path` are replaced with your own paths.\n\nFor other command line arguments, please see the help message.\n\n## Kinetics-400 Main Results\n\nThis is a re-implementation for open-source use.\nWe are still re-running some models, and their scripts, weights and logs will be released later.\nIn the following table we report the re-run accuracy, which may be slightly different from the original paper (typically +/-0.1%)\n\n| Backbone | Decoder Layers | #frames x stride | top-1 | top-5 | Script | 
Model | Log |\n| - | - | - | - | - | - | - | - |\n| ViT-B/16 | 4 | 8 x 16 | 82.8 | 95.8 | [script](scripts/train_k400_vitb16_8f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1DoGjvDdkJoSa9i-wq1lh6QoEZIa4xTB3/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1-9vgsXMpnWBI9MxQV7SSQhkPfLomoYY3/view?usp=sharing) |\n| ViT-B/16 | 4 | 16 x 16 | 83.7 | 96.2 | [script](scripts/train_k400_vitb16_16f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1dax4qUIOEI_QzYXv31J-87cDkonQetVQ/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1l2ivY28jUpwSmafQZvwtUo7tvm42i0PL/view?usp=sharing) |\n| ViT-B/16 | 4 | 32 x 8 | 84.3 | 96.6 | [script](scripts/train_k400_vitb16_32f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1fzFM5pD39Kfp8xRAJuWaXR9RALLmnoeU/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1X1ZOdSCxXVeMpNhr_bviNKlRfJa5SMD7/view?usp=sharing) |\n| ViT-L/14 | 4 | 8 x 16 | 86.3 | 97.2 | [script](scripts/train_k400_vitl14_8f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1AkdF4CkOVW2uiycCVqCxS397oYxNISAI/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1OJFBmaE_tAwTzG-4i0CLQmhwGnN0psx1/view?usp=sharing) |\n| ViT-L/14 | 4 | 16 x 16 | 86.9 | 97.4 | [script](scripts/train_k400_vitl14_16f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1CTV9geLD3HLWzByAQUOf_m0F_g2lE3rg/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1a2iC4tQvjWFMI3UrEv2chuHwVrF6p9YF/view?usp=sharing) |\n| ViT-L/14 | 4 | 32 x 8 | 87.7 | 97.6 | [script](scripts/train_k400_vitl14_32f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1zNFNCKwP5owakELlnTCD20cpVQBqgJrB/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1dK7qoz3McYrmfS09FfreXC-LjUM7l0u4/view?usp=sharing) |\n| ViT-L/14 (336px) | 4 | 32 x 8 | 87.7 | 97.8 | | | |\n\n## Data Loading Speed\n\nAs the training process is fast, video frames are consumed at a very high 
rate.\nFor easier installation, the current version uses PyTorch-builtin data loaders.\nThey are not very efficient and can become a bottleneck when using ViT-B as the backbone.\nWe provide a `--dummy_dataset` option to bypass actual video decoding for training speed measurement. \nThe model accuracy should not be affected. \nOur internal data loader is pure C++-based and does not bottleneck training by much on a machine with 2x Xeon Gold 6148 CPUs and 4x V100 GPUs.\n\n\n## Acknowledgements\n\nThe data loader code is modified from [PySlowFast](https://github.com/facebookresearch/SlowFast). Thanks for their awesome work!\n"
  },
  {
    "path": "checkpoint.py",
    "content": "#!/usr/bin/env python\n\nimport argparse\nimport os\n\nimport torch\nimport torch.distributed as dist\n\n\ndef setup_arg_parser(parser: argparse.ArgumentParser):\n    parser.add_argument('--checkpoint_dir', type=str,\n                        help='checkpoint output path')\n    parser.add_argument('--auto_resume', action='store_true',\n                        help='auto resume from the last checkpoint from checkpoint_dir')\n    parser.add_argument('--resume_path', type=str,\n                        help='resume from manually specified checkpoint file, overriding auto_resume')\n    parser.add_argument('--pretrain', type=str,\n                        help='path to pretrained weights. will NOT override auto_resume or resume_path, '\n                             'load optimizer state or enforce strict matching of checkpoint and model weights.')\n\n\ndef _find_autoresume_path(args: argparse.Namespace):\n    print('Trying to auto resume from path:', args.checkpoint_dir)\n\n    if os.path.isdir(args.checkpoint_dir):\n        checkpoint_files = [x for x in os.listdir(args.checkpoint_dir) if x.startswith('checkpoint-') and x.endswith('.pth')]\n        checkpoint_iters = []\n        for x in checkpoint_files:\n            try:\n                x = x[len('checkpoint-'): -len('.pth')]\n                x = int(x)\n            except ValueError:\n                continue\n            checkpoint_iters.append(x)\n    else:\n        checkpoint_iters = []\n\n    if len(checkpoint_iters) == 0:\n        print('Did not find a valid checkpoint file.')\n    else:\n        checkpoint_iters.sort()\n        args.resume_path = os.path.join(args.checkpoint_dir, 'checkpoint-%d.pth' % checkpoint_iters[-1])\n        print(f'Found {len(checkpoint_iters)} checkpoint file(s).')\n\n\ndef resume_from_checkpoint(\n    model: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    lr_sched: torch.optim.lr_scheduler._LRScheduler,\n    loss_scaler: 
torch.cuda.amp.grad_scaler.GradScaler,\n    args: argparse.Namespace,\n) -> int:\n    if args.pretrain is not None:\n        print(f'Loading pretrain model: {args.pretrain}')\n        ckpt = torch.load(args.pretrain, map_location='cpu')\n        print(model.load_state_dict(ckpt['model'], strict=False))\n\n    # returns resume_step on successful resume, or 0 otherwise.\n    if args.auto_resume and args.resume_path is None:\n        _find_autoresume_path(args)\n    \n    if args.resume_path is None:\n        print('Not resuming from a checkpoint.')\n        return 0\n    else:\n        print(f'Resuming from checkpoint file {args.resume_path}')\n        ckpt = torch.load(args.resume_path, map_location='cpu')\n        model.load_state_dict(ckpt['model'], strict=True)\n        if 'optimizer' in ckpt:\n            optimizer.load_state_dict(ckpt['optimizer'])\n            lr_sched.load_state_dict(ckpt['lr_sched'])\n            loss_scaler.load_state_dict(ckpt['loss_scaler'])\n            return ckpt['next_step']\n        else:\n            print('Optimizer state is NOT found in checkpoint.')\n            return 0\n\n\ndef save_checkpoint(\n    model: torch.nn.Module,\n    optimizer: torch.optim.Optimizer,\n    lr_sched: torch.optim.lr_scheduler._LRScheduler,\n    loss_scaler: torch.cuda.amp.grad_scaler.GradScaler,\n    next_step: int,\n    args: argparse.Namespace,\n):\n    if args.checkpoint_dir is None:\n        return\n\n    if not os.path.isdir(args.checkpoint_dir):\n        os.makedirs(args.checkpoint_dir)\n    \n    to_save = {\n        'model': model.state_dict(),\n        'optimizer': optimizer.state_dict(),\n        'lr_sched': lr_sched.state_dict(),\n        'loss_scaler': loss_scaler.state_dict(),\n        'next_step': next_step,\n    }\n    torch.save(to_save, os.path.join(args.checkpoint_dir, f'checkpoint-{next_step}.pth'))\n"
  },
  {
    "path": "data/k400_class_mappings.json",
    "content": "[\n  \"abseiling\",\n  \"air drumming\",\n  \"answering questions\",\n  \"applauding\",\n  \"applying cream\",\n  \"archery\",\n  \"arm wrestling\",\n  \"arranging flowers\",\n  \"assembling computer\",\n  \"auctioning\",\n  \"baby waking up\",\n  \"baking cookies\",\n  \"balloon blowing\",\n  \"bandaging\",\n  \"barbequing\",\n  \"bartending\",\n  \"beatboxing\",\n  \"bee keeping\",\n  \"belly dancing\",\n  \"bench pressing\",\n  \"bending back\",\n  \"bending metal\",\n  \"biking through snow\",\n  \"blasting sand\",\n  \"blowing glass\",\n  \"blowing leaves\",\n  \"blowing nose\",\n  \"blowing out candles\",\n  \"bobsledding\",\n  \"bookbinding\",\n  \"bouncing on trampoline\",\n  \"bowling\",\n  \"braiding hair\",\n  \"breading or breadcrumbing\",\n  \"breakdancing\",\n  \"brush painting\",\n  \"brushing hair\",\n  \"brushing teeth\",\n  \"building cabinet\",\n  \"building shed\",\n  \"bungee jumping\",\n  \"busking\",\n  \"canoeing or kayaking\",\n  \"capoeira\",\n  \"carrying baby\",\n  \"cartwheeling\",\n  \"carving pumpkin\",\n  \"catching fish\",\n  \"catching or throwing baseball\",\n  \"catching or throwing frisbee\",\n  \"catching or throwing softball\",\n  \"celebrating\",\n  \"changing oil\",\n  \"changing wheel\",\n  \"checking tires\",\n  \"cheerleading\",\n  \"chopping wood\",\n  \"clapping\",\n  \"clay pottery making\",\n  \"clean and jerk\",\n  \"cleaning floor\",\n  \"cleaning gutters\",\n  \"cleaning pool\",\n  \"cleaning shoes\",\n  \"cleaning toilet\",\n  \"cleaning windows\",\n  \"climbing a rope\",\n  \"climbing ladder\",\n  \"climbing tree\",\n  \"contact juggling\",\n  \"cooking chicken\",\n  \"cooking egg\",\n  \"cooking on campfire\",\n  \"cooking sausages\",\n  \"counting money\",\n  \"country line dancing\",\n  \"cracking neck\",\n  \"crawling baby\",\n  \"crossing river\",\n  \"crying\",\n  \"curling hair\",\n  \"cutting nails\",\n  \"cutting pineapple\",\n  \"cutting watermelon\",\n  \"dancing ballet\",\n  \"dancing 
charleston\",\n  \"dancing gangnam style\",\n  \"dancing macarena\",\n  \"deadlifting\",\n  \"decorating the christmas tree\",\n  \"digging\",\n  \"dining\",\n  \"disc golfing\",\n  \"diving cliff\",\n  \"dodgeball\",\n  \"doing aerobics\",\n  \"doing laundry\",\n  \"doing nails\",\n  \"drawing\",\n  \"dribbling basketball\",\n  \"drinking\",\n  \"drinking beer\",\n  \"drinking shots\",\n  \"driving car\",\n  \"driving tractor\",\n  \"drop kicking\",\n  \"drumming fingers\",\n  \"dunking basketball\",\n  \"dying hair\",\n  \"eating burger\",\n  \"eating cake\",\n  \"eating carrots\",\n  \"eating chips\",\n  \"eating doughnuts\",\n  \"eating hotdog\",\n  \"eating ice cream\",\n  \"eating spaghetti\",\n  \"eating watermelon\",\n  \"egg hunting\",\n  \"exercising arm\",\n  \"exercising with an exercise ball\",\n  \"extinguishing fire\",\n  \"faceplanting\",\n  \"feeding birds\",\n  \"feeding fish\",\n  \"feeding goats\",\n  \"filling eyebrows\",\n  \"finger snapping\",\n  \"fixing hair\",\n  \"flipping pancake\",\n  \"flying kite\",\n  \"folding clothes\",\n  \"folding napkins\",\n  \"folding paper\",\n  \"front raises\",\n  \"frying vegetables\",\n  \"garbage collecting\",\n  \"gargling\",\n  \"getting a haircut\",\n  \"getting a tattoo\",\n  \"giving or receiving award\",\n  \"golf chipping\",\n  \"golf driving\",\n  \"golf putting\",\n  \"grinding meat\",\n  \"grooming dog\",\n  \"grooming horse\",\n  \"gymnastics tumbling\",\n  \"hammer throw\",\n  \"headbanging\",\n  \"headbutting\",\n  \"high jump\",\n  \"high kick\",\n  \"hitting baseball\",\n  \"hockey stop\",\n  \"holding snake\",\n  \"hopscotch\",\n  \"hoverboarding\",\n  \"hugging\",\n  \"hula hooping\",\n  \"hurdling\",\n  \"hurling (sport)\",\n  \"ice climbing\",\n  \"ice fishing\",\n  \"ice skating\",\n  \"ironing\",\n  \"javelin throw\",\n  \"jetskiing\",\n  \"jogging\",\n  \"juggling balls\",\n  \"juggling fire\",\n  \"juggling soccer ball\",\n  \"jumping into pool\",\n  \"jumpstyle dancing\",\n  
\"kicking field goal\",\n  \"kicking soccer ball\",\n  \"kissing\",\n  \"kitesurfing\",\n  \"knitting\",\n  \"krumping\",\n  \"laughing\",\n  \"laying bricks\",\n  \"long jump\",\n  \"lunge\",\n  \"making a cake\",\n  \"making a sandwich\",\n  \"making bed\",\n  \"making jewelry\",\n  \"making pizza\",\n  \"making snowman\",\n  \"making sushi\",\n  \"making tea\",\n  \"marching\",\n  \"massaging back\",\n  \"massaging feet\",\n  \"massaging legs\",\n  \"massaging person's head\",\n  \"milking cow\",\n  \"mopping floor\",\n  \"motorcycling\",\n  \"moving furniture\",\n  \"mowing lawn\",\n  \"news anchoring\",\n  \"opening bottle\",\n  \"opening present\",\n  \"paragliding\",\n  \"parasailing\",\n  \"parkour\",\n  \"passing American football (in game)\",\n  \"passing American football (not in game)\",\n  \"peeling apples\",\n  \"peeling potatoes\",\n  \"petting animal (not cat)\",\n  \"petting cat\",\n  \"picking fruit\",\n  \"planting trees\",\n  \"plastering\",\n  \"playing accordion\",\n  \"playing badminton\",\n  \"playing bagpipes\",\n  \"playing basketball\",\n  \"playing bass guitar\",\n  \"playing cards\",\n  \"playing cello\",\n  \"playing chess\",\n  \"playing clarinet\",\n  \"playing controller\",\n  \"playing cricket\",\n  \"playing cymbals\",\n  \"playing didgeridoo\",\n  \"playing drums\",\n  \"playing flute\",\n  \"playing guitar\",\n  \"playing harmonica\",\n  \"playing harp\",\n  \"playing ice hockey\",\n  \"playing keyboard\",\n  \"playing kickball\",\n  \"playing monopoly\",\n  \"playing organ\",\n  \"playing paintball\",\n  \"playing piano\",\n  \"playing poker\",\n  \"playing recorder\",\n  \"playing saxophone\",\n  \"playing squash or racquetball\",\n  \"playing tennis\",\n  \"playing trombone\",\n  \"playing trumpet\",\n  \"playing ukulele\",\n  \"playing violin\",\n  \"playing volleyball\",\n  \"playing xylophone\",\n  \"pole vault\",\n  \"presenting weather forecast\",\n  \"pull ups\",\n  \"pumping fist\",\n  \"pumping gas\",\n  \"punching 
bag\",\n  \"punching person (boxing)\",\n  \"push up\",\n  \"pushing car\",\n  \"pushing cart\",\n  \"pushing wheelchair\",\n  \"reading book\",\n  \"reading newspaper\",\n  \"recording music\",\n  \"riding a bike\",\n  \"riding camel\",\n  \"riding elephant\",\n  \"riding mechanical bull\",\n  \"riding mountain bike\",\n  \"riding mule\",\n  \"riding or walking with horse\",\n  \"riding scooter\",\n  \"riding unicycle\",\n  \"ripping paper\",\n  \"robot dancing\",\n  \"rock climbing\",\n  \"rock scissors paper\",\n  \"roller skating\",\n  \"running on treadmill\",\n  \"sailing\",\n  \"salsa dancing\",\n  \"sanding floor\",\n  \"scrambling eggs\",\n  \"scuba diving\",\n  \"setting table\",\n  \"shaking hands\",\n  \"shaking head\",\n  \"sharpening knives\",\n  \"sharpening pencil\",\n  \"shaving head\",\n  \"shaving legs\",\n  \"shearing sheep\",\n  \"shining shoes\",\n  \"shooting basketball\",\n  \"shooting goal (soccer)\",\n  \"shot put\",\n  \"shoveling snow\",\n  \"shredding paper\",\n  \"shuffling cards\",\n  \"side kick\",\n  \"sign language interpreting\",\n  \"singing\",\n  \"situp\",\n  \"skateboarding\",\n  \"ski jumping\",\n  \"skiing (not slalom or crosscountry)\",\n  \"skiing crosscountry\",\n  \"skiing slalom\",\n  \"skipping rope\",\n  \"skydiving\",\n  \"slacklining\",\n  \"slapping\",\n  \"sled dog racing\",\n  \"smoking\",\n  \"smoking hookah\",\n  \"snatch weight lifting\",\n  \"sneezing\",\n  \"sniffing\",\n  \"snorkeling\",\n  \"snowboarding\",\n  \"snowkiting\",\n  \"snowmobiling\",\n  \"somersaulting\",\n  \"spinning poi\",\n  \"spray painting\",\n  \"spraying\",\n  \"springboard diving\",\n  \"squat\",\n  \"sticking tongue out\",\n  \"stomping grapes\",\n  \"stretching arm\",\n  \"stretching leg\",\n  \"strumming guitar\",\n  \"surfing crowd\",\n  \"surfing water\",\n  \"sweeping floor\",\n  \"swimming backstroke\",\n  \"swimming breast stroke\",\n  \"swimming butterfly stroke\",\n  \"swing dancing\",\n  \"swinging legs\",\n  \"swinging on 
something\",\n  \"sword fighting\",\n  \"tai chi\",\n  \"taking a shower\",\n  \"tango dancing\",\n  \"tap dancing\",\n  \"tapping guitar\",\n  \"tapping pen\",\n  \"tasting beer\",\n  \"tasting food\",\n  \"testifying\",\n  \"texting\",\n  \"throwing axe\",\n  \"throwing ball\",\n  \"throwing discus\",\n  \"tickling\",\n  \"tobogganing\",\n  \"tossing coin\",\n  \"tossing salad\",\n  \"training dog\",\n  \"trapezing\",\n  \"trimming or shaving beard\",\n  \"trimming trees\",\n  \"triple jump\",\n  \"tying bow tie\",\n  \"tying knot (not on a tie)\",\n  \"tying tie\",\n  \"unboxing\",\n  \"unloading truck\",\n  \"using computer\",\n  \"using remote controller (not gaming)\",\n  \"using segway\",\n  \"vault\",\n  \"waiting in line\",\n  \"walking the dog\",\n  \"washing dishes\",\n  \"washing feet\",\n  \"washing hair\",\n  \"washing hands\",\n  \"water skiing\",\n  \"water sliding\",\n  \"watering plants\",\n  \"waxing back\",\n  \"waxing chest\",\n  \"waxing eyebrows\",\n  \"waxing legs\",\n  \"weaving basket\",\n  \"welding\",\n  \"whistling\",\n  \"windsurfing\",\n  \"wrapping present\",\n  \"wrestling\",\n  \"writing\",\n  \"yawning\",\n  \"yoga\",\n  \"zumba\"\n]\n"
  },
  {
    "path": "main.py",
    "content": "#!/usr/bin/env python\n\nimport argparse\nfrom datetime import datetime\nimport builtins\n\nimport torch\nimport torch.distributed as dist\n\nimport video_dataset\nimport checkpoint\nfrom model import EVLTransformer\nfrom video_dataset import dataloader\nfrom weight_loaders import weight_loader_fn_dict\nfrom vision_transformer import vit_presets\n\ndef setup_print(is_master: bool):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    builtin_print = builtins.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        if is_master or force:\n            now = datetime.now().time()\n            builtin_print('[{}] '.format(now), end='')  # print with time stamp\n            builtin_print(*args, **kwargs)\n\n    builtins.print = print\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    \n    video_dataset.setup_arg_parser(parser)\n    checkpoint.setup_arg_parser(parser)\n\n    parser.add_argument('--num_steps', type=int,\n                        help='number of training steps')\n    parser.add_argument('--eval_only', action='store_true',\n                        help='run evaluation only')\n    parser.add_argument('--save_freq', type=int, default=5000,\n                        help='save a checkpoint every N steps')\n    parser.add_argument('--eval_freq', type=int, default=5000,\n                        help='evaluate every N steps')\n    parser.add_argument('--print_freq', type=int, default=10,\n                        help='print log message every N steps')\n\n    parser.add_argument('--backbone', type=str, choices=vit_presets.keys(), default='ViT-B/16-lnpre',\n                        help='the backbone variant used to generate image feature maps')\n    parser.add_argument('--backbone_path', type=str,\n                        help='path to pretrained backbone weights')\n    parser.add_argument('--backbone_type', type=str, default='clip', 
choices=weight_loader_fn_dict.keys(),\n                        help='type of backbone weights (used to determine how to convert state_dict from different pretraining codebase)')\n    parser.add_argument('--finetune_backbone', action='store_true',\n                        help='finetune backbone weights')\n    parser.add_argument('--decoder_num_layers', type=int, default=4,\n                        help='number of decoder layers')\n    parser.add_argument('--decoder_qkv_dim', type=int, default=768,\n                        help='q (k, v) projection output dimensions in decoder attention layers')\n    parser.add_argument('--decoder_num_heads', type=int, default=12,\n                        help='number of heads in decoder attention layers')\n    parser.add_argument('--decoder_mlp_factor', type=float, default=4.0,\n                        help='expansion factor of feature dimension in the middle of decoder MLPs')\n    parser.add_argument('--num_classes', type=int, default=400,\n                        help='number of classes')\n    parser.add_argument('--cls_dropout', type=float, default=0.5,\n                        help='dropout rate applied before the final classification linear projection')\n    parser.add_argument('--decoder_mlp_dropout', type=float, default=0.5,\n                        help='dropout rate applied in MLP layers in the decoder')\n    parser.add_argument('--no_temporal_conv', action='store_false', dest='temporal_conv',\n                        help='disable temporal convolution on frame features')\n    parser.add_argument('--no_temporal_pos_embed', action='store_false', dest='temporal_pos_embed',\n                        help='disable temporal position embeddings added to frame features')\n    parser.add_argument('--no_temporal_cross_attention', action='store_false', dest='temporal_cross_attention',\n                        help='disable temporal cross attention on frame query and key features')\n    parser.set_defaults(temporal_conv=True, 
temporal_pos_embed=True, temporal_cross_attention=True)\n\n    parser.add_argument('--lr', type=float, default=4e-4,\n                        help='learning rate')\n    parser.add_argument('--weight_decay', type=float, default=0.05,\n                        help='optimizer weight decay')\n    parser.add_argument('--disable_fp16', action='store_false', dest='fp16',\n                        help='disable fp16 during training or inference')\n    parser.set_defaults(fp16=True)\n\n    parser.add_argument('--batch_split', type=int, default=1,\n                        help='optionally split the batch into smaller shards and forward/backward one shard '\n                             'at a time to avoid out-of-memory error.')\n\n    args = parser.parse_args()\n\n    dist.init_process_group('nccl')\n    setup_print(dist.get_rank() == 0)\n    cuda_device_id = dist.get_rank() % torch.cuda.device_count()\n    torch.cuda.set_device(cuda_device_id)\n\n    model = EVLTransformer(\n        backbone_name=args.backbone,\n        backbone_type=args.backbone_type,\n        backbone_path=args.backbone_path,\n        backbone_mode='finetune' if args.finetune_backbone else ('freeze_fp16' if args.fp16 else 'freeze_fp32'),\n        decoder_num_layers=args.decoder_num_layers,\n        decoder_qkv_dim=args.decoder_qkv_dim,\n        decoder_num_heads=args.decoder_num_heads,\n        decoder_mlp_factor=args.decoder_mlp_factor,\n        num_classes=args.num_classes,\n        enable_temporal_conv=args.temporal_conv,\n        enable_temporal_pos_embed=args.temporal_pos_embed,\n        enable_temporal_cross_attention=args.temporal_cross_attention,\n        cls_dropout=args.cls_dropout,\n        decoder_mlp_dropout=args.decoder_mlp_dropout,\n        num_frames=args.num_frames,\n    )\n    print(model)\n    model.cuda()\n    model = torch.nn.parallel.DistributedDataParallel(\n        model, device_ids=[cuda_device_id], output_device=cuda_device_id,\n    )\n\n    optimizer = 
torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)\n    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)\n    loss_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=args.fp16)\n    criterion = torch.nn.CrossEntropyLoss()\n\n    resume_step = checkpoint.resume_from_checkpoint(model, optimizer, lr_sched, loss_scaler, args)\n\n    val_loader = video_dataset.create_val_loader(args)\n    if args.eval_only:\n        print('Running in eval_only mode.')\n        model.eval()\n        evaluate(model, val_loader)\n        return\n    else:\n        assert args.train_list_path is not None, 'Train list path must be specified if not in eval_only mode.'\n        train_loader = video_dataset.create_train_loader(args, resume_step=resume_step)\n\n    assert len(train_loader) == args.num_steps - resume_step\n    batch_st, train_st = datetime.now(), datetime.now()\n    for i, (data, labels) in enumerate(train_loader, resume_step):\n        data, labels = data.cuda(), labels.cuda()\n        data_ed = datetime.now()\n\n        optimizer.zero_grad()\n\n        assert data.size(0) % args.batch_split == 0\n        split_size = data.size(0) // args.batch_split\n        hit1, hit5, loss_value = 0, 0, 0\n        for j in range(args.batch_split):\n            data_slice = data[split_size * j: split_size * (j + 1)]\n            labels_slice = labels[split_size * j: split_size * (j + 1)]\n\n            with torch.cuda.amp.autocast(args.fp16):\n                logits = model(data_slice)\n                loss = criterion(logits, labels_slice)\n                \n            if labels.dtype == torch.long: # no mixup, can calculate accuracy\n                hit1 += (logits.topk(1, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()\n                hit5 += (logits.topk(5, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()\n            loss_value += loss.item() / args.batch_split\n            \n            
loss_scaler.scale(loss / args.batch_split).backward()\n        \n        loss_scaler.step(optimizer)\n        loss_scaler.update()\n        lr_sched.step()\n\n        batch_ed = datetime.now()\n\n        if i % args.print_freq == 0:\n            sync_tensor = torch.Tensor([loss_value, hit1 / data.size(0), hit5 / data.size(0)]).cuda()\n            dist.all_reduce(sync_tensor)\n            sync_tensor = sync_tensor.cpu() / dist.get_world_size()\n            loss_value, acc1, acc5 = sync_tensor.tolist()\n\n            print(\n                f'batch_time: {(batch_ed - batch_st).total_seconds():.3f}  '\n                f'data_time: {(data_ed - batch_st).total_seconds():.3f}  '\n                f'ETA: {(batch_ed - train_st) / (i - resume_step + 1) * (args.num_steps - i - 1)}  |  '\n                f'lr: {optimizer.param_groups[0][\"lr\"]:.6f}  '\n                f'loss: {loss_value:.6f}' + (\n                    f'  acc1: {acc1 * 100:.2f}%  acc5: {acc5 * 100:.2f}%' if labels.dtype == torch.long else ''\n                )\n            )\n        \n        if (i + 1) % args.eval_freq == 0:\n            print('Start model evaluation at step', i + 1)\n            model.eval()\n            evaluate(model, val_loader)\n            model.train()\n\n        if (i + 1) % args.save_freq == 0:\n            checkpoint.save_checkpoint(model, optimizer, lr_sched, loss_scaler, i + 1, args)\n        \n        batch_st = datetime.now()\n\n\ndef evaluate(model: torch.nn.Module, loader: torch.utils.data.DataLoader):\n    tot, hit1, hit5 = 0, 0, 0\n    eval_st = datetime.now()\n    for data, labels in loader:\n        data, labels = data.cuda(), labels.cuda()\n        assert data.size(0) == 1\n        if data.ndim == 6:\n            data = data[0] # now the first dimension is number of views\n\n        with torch.no_grad():\n            logits = model(data)\n            scores = logits.softmax(dim=-1).mean(dim=0)\n\n        tot += 1\n        hit1 += (scores.topk(1)[1] == 
labels).sum().item()\n        hit5 += (scores.topk(5)[1] == labels).sum().item()\n\n        if tot % 20 == 0:\n            print(f'[Evaluation] num_samples: {tot}  '\n                  f'ETA: {(datetime.now() - eval_st) / tot * (len(loader) - tot)}  '\n                  f'cumulative_acc1: {hit1 / tot * 100.:.2f}%  '\n                  f'cumulative_acc5: {hit5 / tot * 100.:.2f}%')\n\n    sync_tensor = torch.LongTensor([tot, hit1, hit5]).cuda()\n    dist.all_reduce(sync_tensor)\n    tot, hit1, hit5 = sync_tensor.cpu().tolist()\n\n    print(f'Accuracy on validation set: top1={hit1 / tot * 100:.2f}%, top5={hit5 / tot * 100:.2f}%')\n\n\nif __name__ == '__main__': main()\n"
  },
  {
    "path": "model.py",
    "content": "#!/usr/bin/env python\n\nfrom typing import Dict, Iterable, List, Tuple\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom vision_transformer import QuickGELU, Attention\nfrom weight_loaders import weight_loader_fn_dict\nfrom vision_transformer import (\n    VisionTransformer2D, TransformerDecoderLayer,\n    model_to_fp16, vit_presets,\n)\n        \n\nclass TemporalCrossAttention(nn.Module):\n\n    def __init__(\n        self,\n        spatial_size: Tuple[int, int] = (14, 14),\n        feature_dim: int = 768,\n    ):\n        super().__init__()\n\n        self.spatial_size = spatial_size\n\n        w_size = np.prod([x * 2 - 1 for x in spatial_size])\n        self.w1 = nn.Parameter(torch.zeros([w_size, feature_dim]))\n        self.w2 = nn.Parameter(torch.zeros([w_size, feature_dim]))\n\n        idx_tensor = torch.zeros([np.prod(spatial_size) for _ in (0, 1)], dtype=torch.long)\n        for q in range(np.prod(spatial_size)):\n            qi, qj = q // spatial_size[1], q % spatial_size[1]\n            for k in range(np.prod(spatial_size)):\n                ki, kj = k // spatial_size[1], k % spatial_size[1]\n                i_offs = qi - ki + spatial_size[0] - 1\n                j_offs = qj - kj + spatial_size[1] - 1\n                idx_tensor[q, k] = i_offs * (spatial_size[1] * 2 - 1) + j_offs\n        self.idx_tensor = idx_tensor\n\n\n    def forward_half(self, q: torch.Tensor, k: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n        q, k = q[:, :, 1:], k[:, :, 1:] # remove cls token\n\n        assert q.size() == k.size()\n        assert q.size(2) == np.prod(self.spatial_size)\n\n        attn = torch.einsum('ntqhd,ntkhd->ntqkh', q / (q.size(-1) ** 0.5), k)\n        attn = attn.softmax(dim=-2).mean(dim=-1) # N, T, L, L (softmax over keys, averaged over heads)\n\n        self.idx_tensor = self.idx_tensor.to(w.device)\n        w_unroll = w[self.idx_tensor] # L, L, C\n        ret = torch.einsum('ntqk,qkc->ntqc', attn, w_unroll)\n\n        
return ret\n\n\n    def forward(self, q: torch.Tensor, k: torch.Tensor):\n        N, T, L, H, D = q.size()\n        assert L == np.prod(self.spatial_size) + 1\n\n        # allocate on the same device as the inputs instead of hard-coding 'cuda'\n        ret = torch.zeros([N, T, L, self.w1.size(-1)], device=q.device)\n        ret[:, 1:, 1:, :] += self.forward_half(q[:, 1:, :, :, :], k[:, :-1, :, :, :], self.w1)\n        ret[:, :-1, 1:, :] += self.forward_half(q[:, :-1, :, :, :], k[:, 1:, :, :, :], self.w2)\n\n        return ret\n\n\nclass EVLDecoder(nn.Module):\n\n    def __init__(\n        self,\n        num_frames: int = 8,\n        spatial_size: Tuple[int, int] = (14, 14),\n        num_layers: int = 4,\n        in_feature_dim: int = 768,\n        qkv_dim: int = 768,\n        num_heads: int = 12,\n        mlp_factor: float = 4.0,\n        enable_temporal_conv: bool = True,\n        enable_temporal_pos_embed: bool = True,\n        enable_temporal_cross_attention: bool = True,\n        mlp_dropout: float = 0.5,\n    ):\n        super().__init__()\n\n        self.enable_temporal_conv = enable_temporal_conv\n        self.enable_temporal_pos_embed = enable_temporal_pos_embed\n        self.enable_temporal_cross_attention = enable_temporal_cross_attention\n        self.num_layers = num_layers\n\n        self.decoder_layers = nn.ModuleList(\n            [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, mlp_factor, mlp_dropout) for _ in range(num_layers)]\n        )\n\n        if enable_temporal_conv:\n            self.temporal_conv = nn.ModuleList(\n                [nn.Conv1d(in_feature_dim, in_feature_dim, kernel_size=3, stride=1, padding=1, groups=in_feature_dim) for _ in range(num_layers)]\n            )\n        if enable_temporal_pos_embed:\n            self.temporal_pos_embed = nn.ParameterList(\n                [nn.Parameter(torch.zeros([num_frames, in_feature_dim])) for _ in range(num_layers)]\n            )\n        if enable_temporal_cross_attention:\n            self.cross_attention = nn.ModuleList(\n                
[TemporalCrossAttention(spatial_size, in_feature_dim) for _ in range(num_layers)]\n            )\n\n        self.cls_token = nn.Parameter(torch.zeros([in_feature_dim]))\n\n\n    def _initialize_weights(self):\n        nn.init.normal_(self.cls_token, std=0.02)\n\n\n    def forward(self, in_features: List[Dict[str, torch.Tensor]]):\n        N, T, L, C = in_features[0]['out'].size()\n        assert len(in_features) == self.num_layers\n        x = self.cls_token.view(1, 1, -1).repeat(N, 1, 1)\n\n        for i in range(self.num_layers):\n            frame_features = in_features[i]['out']\n            \n            if self.enable_temporal_conv:\n                feat = in_features[i]['out']\n                feat = feat.permute(0, 2, 3, 1).contiguous().flatten(0, 1) # N * L, C, T\n                feat = self.temporal_conv[i](feat)\n                feat = feat.view(N, L, C, T).permute(0, 3, 1, 2).contiguous() # N, T, L, C\n                frame_features += feat\n            \n            if self.enable_temporal_pos_embed:\n                frame_features += self.temporal_pos_embed[i].view(1, T, 1, C)\n            \n            if self.enable_temporal_cross_attention:\n                frame_features += self.cross_attention[i](in_features[i]['q'], in_features[i]['k'])\n\n            frame_features = frame_features.flatten(1, 2) # N, T * L, C\n            \n            x = self.decoder_layers[i](x, frame_features)\n        \n        return x\n\n\nclass EVLTransformer(nn.Module):\n\n    def __init__(\n        self,\n        num_frames: int = 8,\n        backbone_name: str = 'ViT-B/16',\n        backbone_type: str = 'clip',\n        backbone_path: str = '',\n        backbone_mode: str = 'freeze_fp16',\n        decoder_num_layers: int = 4,\n        decoder_qkv_dim: int = 768,\n        decoder_num_heads: int = 12,\n        decoder_mlp_factor: float = 4.0,\n        num_classes: int = 400,\n        enable_temporal_conv: bool = True,\n        enable_temporal_pos_embed: bool = True,\n  
      enable_temporal_cross_attention: bool = True,\n        cls_dropout: float = 0.5,\n        decoder_mlp_dropout: float = 0.5,\n    ):\n        super().__init__()\n\n        self.decoder_num_layers = decoder_num_layers\n\n        backbone_config = self._create_backbone(backbone_name, backbone_type, backbone_path, backbone_mode)\n        backbone_feature_dim = backbone_config['feature_dim']\n        backbone_spatial_size = tuple(x // y for x, y in zip(backbone_config['input_size'], backbone_config['patch_size']))\n\n        self.decoder = EVLDecoder(\n            num_frames=num_frames,\n            spatial_size=backbone_spatial_size,\n            num_layers=decoder_num_layers,\n            in_feature_dim=backbone_feature_dim,\n            qkv_dim=decoder_qkv_dim,\n            num_heads=decoder_num_heads,\n            mlp_factor=decoder_mlp_factor,\n            enable_temporal_conv=enable_temporal_conv,\n            enable_temporal_pos_embed=enable_temporal_pos_embed,\n            enable_temporal_cross_attention=enable_temporal_cross_attention,\n            mlp_dropout=decoder_mlp_dropout,\n        )\n        self.proj = nn.Sequential(\n            nn.LayerNorm(backbone_feature_dim),\n            nn.Dropout(cls_dropout),\n            nn.Linear(backbone_feature_dim, num_classes),\n        )\n\n\n    def _create_backbone(\n        self,\n        backbone_name: str,\n        backbone_type: str,\n        backbone_path: str,\n        backbone_mode: str,\n    ) -> dict:\n        weight_loader_fn = weight_loader_fn_dict[backbone_type]\n        state_dict = weight_loader_fn(backbone_path)\n\n        backbone = VisionTransformer2D(return_all_features=True, **vit_presets[backbone_name])\n        backbone.load_state_dict(state_dict, strict=True) # weight_loader_fn is expected to strip unused parameters\n\n        assert backbone_mode in ['finetune', 'freeze_fp16', 'freeze_fp32']\n\n        if backbone_mode == 'finetune':\n            self.backbone = backbone\n        else:\n 
           backbone.eval().requires_grad_(False)\n            if backbone_mode == 'freeze_fp16':\n                model_to_fp16(backbone)\n            self.backbone = [backbone] # avoid backbone parameter registration\n\n        return vit_presets[backbone_name]\n\n\n    def _get_backbone(self, x):\n        if isinstance(self.backbone, list):\n            # frozen backbone\n            self.backbone[0] = self.backbone[0].to(x.device)\n            return self.backbone[0]\n        else:\n            # fine-tuned backbone\n            return self.backbone\n\n\n    def forward(self, x: torch.Tensor):\n        backbone = self._get_backbone(x)\n\n        B, C, T, H, W = x.size()\n        x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)\n        features = backbone(x)[-self.decoder_num_layers:]\n        features = [\n            dict((k, v.float().view(B, T, *v.size()[1:])) for k, v in x.items())\n            for x in features\n        ]\n\n        x = self.decoder(features)\n        x = self.proj(x[:, 0, :])\n\n        return x"
  },
  {
    "path": "scripts/eval_k400_vitb16_16f_dec4x768.sh",
    "content": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-B/16-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-B-16.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 768 \\\n    --decoder_num_heads 12 \\\n    --num_classes 400 \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 16 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n    --resume_path /path/to/checkpoint_release/k400_vitb16_16f_dec4x768.pth \\\n    --eval_only\n"
  },
  {
    "path": "scripts/eval_k400_vitb16_32f_dec4x768.sh",
    "content": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-B/16-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-B-16.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 768 \\\n    --decoder_num_heads 12 \\\n    --num_classes 400 \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 32 \\\n    --sampling_rate 8 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n    --resume_path /path/to/checkpoint_release/k400_vitb16_32f_dec4x768.pth \\\n    --eval_only\n"
  },
  {
    "path": "scripts/eval_k400_vitb16_8f_dec4x768.sh",
    "content": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-B/16-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-B-16.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 768 \\\n    --decoder_num_heads 12 \\\n    --num_classes 400 \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 8 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 1 \\\n    --num_temporal_views 3 \\\n    --resume_path /path/to/checkpoint_release/k400_vitb16_8f_dec4x768.pth \\\n    --eval_only\n"
  },
  {
    "path": "scripts/eval_k400_vitl14_16f_dec4x1024.sh",
    "content": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-L/14-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-L-14.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 1024 \\\n    --decoder_num_heads 16 \\\n    --num_classes 400 \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 16 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n    --resume_path /path/to/checkpoint_release/k400_vitl14_16f_dec4x1024.pth \\\n    --eval_only\n"
  },
  {
    "path": "scripts/eval_k400_vitl14_32f_dec4x1024.sh",
    "content": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-L/14-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-L-14.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 1024 \\\n    --decoder_num_heads 16 \\\n    --num_classes 400 \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 32 \\\n    --sampling_rate 8 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n    --resume_path /path/to/checkpoint_release/k400_vitl14_32f_dec4x1024.pth \\\n    --eval_only\n"
  },
  {
    "path": "scripts/eval_k400_vitl14_8f_dec4x1024.sh",
    "content": "#!/usr/bin/env sh\n\npython -u -m torch.distributed.run --nproc_per_node 4 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-L/14-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-L-14.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 1024 \\\n    --decoder_num_heads 16 \\\n    --num_classes 400 \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 8 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 1 \\\n    --num_temporal_views 3 \\\n    --resume_path /path/to/checkpoint_release/k400_vitl14_8f_dec4x1024.pth \\\n    --eval_only\n"
  },
  {
    "path": "scripts/train_k400_vitb16_16f_dec4x768.sh",
    "content": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitb16_16f_dec4x768\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --nproc_per_node 8 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-B/16-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-B-16.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 768 \\\n    --decoder_num_heads 12 \\\n    --num_classes 400 \\\n    --checkpoint_dir \"${exp_dir}\" \\\n    --auto_resume \\\n    --train_list_path /path/to/k400/train.txt \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 16 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n  2>&1 | tee \"${exp_dir}/train-$(date +\"%Y%m%d_%H%M%S\").log\"\n"
  },
  {
    "path": "scripts/train_k400_vitb16_32f_dec4x768.sh",
    "content": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitb16_32f_dec4x768\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --nproc_per_node 8 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-B/16-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-B-16.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 768 \\\n    --decoder_num_heads 12 \\\n    --num_classes 400 \\\n    --checkpoint_dir \"${exp_dir}\" \\\n    --auto_resume \\\n    --train_list_path /path/to/k400/train.txt \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 32 \\\n    --sampling_rate 8 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n  2>&1 | tee \"${exp_dir}/train-$(date +\"%Y%m%d_%H%M%S\").log\"\n"
  },
  {
    "path": "scripts/train_k400_vitb16_8f_dec4x768.sh",
    "content": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitb16_8f_dec4x768\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --nproc_per_node 8 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-B/16-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-B-16.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 768 \\\n    --decoder_num_heads 12 \\\n    --num_classes 400 \\\n    --checkpoint_dir \"${exp_dir}\" \\\n    --auto_resume \\\n    --train_list_path /path/to/k400/train.txt \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 8 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 1 \\\n    --num_temporal_views 3 \\\n  2>&1 | tee \"${exp_dir}/train-$(date +\"%Y%m%d_%H%M%S\").log\"\n"
  },
  {
    "path": "scripts/train_k400_vitl14_16f_dec4x1024.sh",
    "content": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitl14_16f_dec4x1024\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --nproc_per_node 8 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-L/14-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-L-14.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 1024 \\\n    --decoder_num_heads 16 \\\n    --num_classes 400 \\\n    --checkpoint_dir \"${exp_dir}\" \\\n    --auto_resume \\\n    --train_list_path /path/to/k400/train.txt \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 2 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 16 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n  2>&1 | tee \"${exp_dir}/train-$(date +\"%Y%m%d_%H%M%S\").log\"\n"
  },
  {
    "path": "scripts/train_k400_vitl14_32f_dec4x1024.sh",
    "content": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitl14_32f_dec4x1024\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --nproc_per_node 8 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-L/14-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-L-14.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 1024 \\\n    --decoder_num_heads 16 \\\n    --num_classes 400 \\\n    --checkpoint_dir \"${exp_dir}\" \\\n    --auto_resume \\\n    --train_list_path /path/to/k400/train.txt \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 4 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 32 \\\n    --sampling_rate 8 \\\n    --num_spatial_views 3 \\\n    --num_temporal_views 1 \\\n  2>&1 | tee \"${exp_dir}/train-$(date +\"%Y%m%d_%H%M%S\").log\"\n"
  },
  {
    "path": "scripts/train_k400_vitl14_8f_dec4x1024.sh",
    "content": "#!/usr/bin/env sh\n\nexp_dir=runs/k400_vitl14_8f_dec4x1024\n\nmkdir -p \"${exp_dir}\"\npython -u -m torch.distributed.run --nproc_per_node 8 \\\n  main.py \\\n    --num_steps 50000 \\\n    --backbone \"ViT-L/14-lnpre\" \\\n    --backbone_type clip \\\n    --backbone_path /path/to/clip_models/ViT-L-14.pt \\\n    --decoder_num_layers 4 \\\n    --decoder_qkv_dim 1024 \\\n    --decoder_num_heads 16 \\\n    --num_classes 400 \\\n    --checkpoint_dir \"${exp_dir}\" \\\n    --auto_resume \\\n    --train_list_path /path/to/k400/train.txt \\\n    --val_list_path /path/to/k400/val.txt \\\n    --batch_size 256 \\\n    --batch_split 1 \\\n    --auto_augment rand-m7-n4-mstd0.5-inc1 \\\n    --mean 0.48145466 0.4578275 0.40821073 \\\n    --std 0.26862954 0.26130258 0.27577711 \\\n    --num_workers 12 \\\n    --num_frames 8 \\\n    --sampling_rate 16 \\\n    --num_spatial_views 1 \\\n    --num_temporal_views 3 \\\n  2>&1 | tee \"${exp_dir}/train-$(date +\"%Y%m%d_%H%M%S\").log\"\n"
  },
  {
    "path": "video_dataset/__init__.py",
    "content": "#!/usr/bin/env python\n\nfrom .dataloader import setup_arg_parser, create_train_loader, create_val_loader"
  },
  {
    "path": "video_dataset/dataloader.py",
    "content": "#!/usr/bin/env python\n\nimport argparse\nfrom typing import Dict\n\nimport torch\nimport torch.distributed as dist\n\nfrom .dataset import VideoDataset, DummyDataset\n\ndef setup_arg_parser(parser: argparse.ArgumentParser):\n    parser.add_argument('--train_list_path', type=str,\n                        help='path to training data list')\n    parser.add_argument('--val_list_path', type=str,\n                        help='path to validation data list')\n    parser.add_argument('--train_data_root', type=str,\n                        help='training samples root directory')\n    parser.add_argument('--val_data_root', type=str,\n                        help='validation samples root directory')\n    parser.add_argument('--data_root', type=str, default='',\n                        help='training and validation samples root directory, might be overrided by --train_data_root or --val_data_root')\n\n    parser.add_argument('--batch_size', type=int,\n                        help='training batch size on a all GPUs')\n\n    parser.add_argument('--num_spatial_views', type=int, default=1,\n                        help='number of spatial crops used for testing (total views = num_spatial_views * num_temporal_views)')\n    parser.add_argument('--num_temporal_views', type=int, default=3,\n                        help='number of temporal crops used for testing (total views = num_spatial_views * num_temporal_views)')\n    parser.add_argument('--num_frames', type=int, default=8,\n                        help='number of frames used for each view')\n    parser.add_argument('--sampling_rate', type=int, default=16,\n                        help='temporal stride for frame sampling, only valid when tsn_sampling is not enabled')\n    parser.add_argument('--tsn_sampling', action='store_true',\n                        help='enable TSN-style sampling (i.e. 
sample frames with dynamic stride to cover the whole video)')\n    parser.add_argument('--spatial_size', type=int, default=224,\n                        help='frame height and width in pixels')\n\n    parser.add_argument('--mean', type=float, nargs='+',\n                        help='pixel mean used to normalize the image.')\n    parser.add_argument('--std', type=float, nargs='+',\n                        help='pixel std used to normalize the image')\n\n    parser.add_argument('--num_workers', type=int, default=10,\n                        help='number of DataLoader worker threads')\n    \n    parser.add_argument('--dummy_dataset', action='store_true',\n                        help='use fake datasets that generate all 0 (use for speed test only)')\n\n    parser.add_argument('--auto_augment', type=str,\n                        help='auto augment configuration')\n    parser.add_argument('--interpolation', type=str, default='bicubic',\n                        help='interpolation mode')\n    parser.add_argument('--no_mirror', action='store_false', dest='mirror',\n                        help='disable mirror for training (frequently used for the something-something dataset)')\n    parser.set_defaults(mirror=True)\n                        \n\ndef _parse_mean_and_std(args: argparse.Namespace) -> Dict[str, torch.Tensor]:\n    def parse_mean_or_std(arg, default_value):\n        if arg is None:\n            return torch.Tensor([default_value] * 3)\n        elif len(arg) == 1:\n            return torch.Tensor(arg * 3)\n        elif len(arg) == 3:\n            return torch.Tensor(arg)\n        else:\n            raise NotImplementedError()\n    return {\n        'mean': parse_mean_or_std(args.mean, 0.45),\n        'std': parse_mean_or_std(args.std, 0.225),\n    }\n\n\ndef create_train_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:\n    if args.dummy_dataset:\n        return DummyDataset(\n            list_path=args.train_list_path,\n            
num_frames=args.num_frames,\n            num_views=1,\n            spatial_size=args.spatial_size,\n        )\n\n    return VideoDataset(\n        list_path=args.train_list_path,\n        data_root=args.train_data_root or args.data_root,\n        num_spatial_views=1, num_temporal_views=1, random_sample=True,\n        auto_augment=args.auto_augment,\n        interpolation=args.interpolation,\n        mirror=args.mirror,\n        num_frames=args.num_frames,\n        sampling_rate=-1 if args.tsn_sampling else args.sampling_rate,\n        spatial_size=args.spatial_size,\n        **_parse_mean_and_std(args),\n    )\n\n\ndef create_train_loader(args: argparse.Namespace, resume_step: int = 0) -> torch.utils.data.DataLoader:\n    dataset = create_train_dataset(args)\n    rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size())\n\n    assert args.batch_size % world_size == 0\n    batch_size_per_gpu = args.batch_size // world_size\n\n    # manually create a step-based sampler\n    sampler = []\n    while len(sampler) * len(dataset) < args.num_steps * args.batch_size:\n        g = torch.Generator()\n        g.manual_seed(len(sampler))\n        indices = torch.randperm(len(dataset), generator=g)\n        sampler.append(indices)\n    sampler = torch.cat(sampler, dim=0)[:args.num_steps * args.batch_size].view(args.num_steps, args.batch_size)\n    sampler = sampler[resume_step:, batch_size_per_gpu * rank: batch_size_per_gpu * (rank + 1)].flatten().tolist()\n\n    loader = torch.utils.data.DataLoader(\n        dataset, sampler=sampler, batch_size=batch_size_per_gpu,\n        num_workers=args.num_workers, pin_memory=False, drop_last=True,\n    )\n\n    return loader\n\n\ndef create_val_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:\n    if args.dummy_dataset:\n        return DummyDataset(\n            list_path=args.val_list_path,\n            num_frames=args.num_frames,\n            num_views=args.num_spatial_views * 
args.num_temporal_views,\n            spatial_size=args.spatial_size,\n        )\n\n    return VideoDataset(\n        list_path=args.val_list_path,\n        data_root=args.val_data_root or args.data_root,\n        num_spatial_views=args.num_spatial_views,\n        num_temporal_views=args.num_temporal_views,\n        random_sample=False,\n        num_frames=args.num_frames,\n        sampling_rate=-1 if args.tsn_sampling else args.sampling_rate,\n        spatial_size=args.spatial_size,\n        **_parse_mean_and_std(args),\n    )\n\n\ndef create_val_loader(args: argparse.Namespace) -> torch.utils.data.DataLoader:\n    dataset = create_val_dataset(args)\n    rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size())\n\n    # sampler for distributed eval\n    sampler = list(range(rank, len(dataset), world_size))\n\n    loader = torch.utils.data.DataLoader(\n        dataset, sampler=sampler, batch_size=1,\n        num_workers=args.num_workers, pin_memory=False,\n    )\n\n    return loader\n"
  },
  {
    "path": "video_dataset/dataset.py",
    "content": "#!/usr/bin/env python\n\nimport os, sys\nfrom typing import Optional\nimport av\nimport io\nimport numpy as np\n\nimport torch\nfrom torchvision import transforms\n\nfrom .transform import create_random_augment, random_resized_crop\n\nclass VideoDataset(torch.utils.data.Dataset):\n\n    def __init__(\n        self, list_path: str, data_root: str,\n        num_spatial_views: int, num_temporal_views: int, random_sample: bool,\n        num_frames: int, sampling_rate: int, spatial_size: int,\n        mean: torch.Tensor, std: torch.Tensor,\n        auto_augment: Optional[str] = None, interpolation: str = 'bicubic',\n        mirror: bool = False,\n    ):\n        self.data_root = data_root\n        self.interpolation = interpolation\n        self.spatial_size = spatial_size\n\n        self.mean, self.std = mean, std\n        self.num_frames, self.sampling_rate = num_frames, sampling_rate\n\n        if random_sample:\n            assert num_spatial_views == 1 and num_temporal_views == 1\n            self.random_sample = True\n            self.mirror = mirror\n            self.auto_augment = auto_augment\n        else:\n            assert auto_augment is None and not mirror\n            self.random_sample = False\n            self.num_temporal_views = num_temporal_views\n            self.num_spatial_views = num_spatial_views\n\n        with open(list_path) as f:\n            self.data_list = f.read().splitlines()\n\n\n    def __len__(self):\n        return len(self.data_list)\n    \n\n    def __getitem__(self, idx):\n        line = self.data_list[idx]\n        path, label = line.split(' ')\n        path = os.path.join(self.data_root, path)\n        label = int(label)\n\n        container = av.open(path)\n        frames = {}\n        for frame in container.decode(video=0):\n            frames[frame.pts] = frame\n        container.close()\n        frames = [frames[k] for k in sorted(frames.keys())]\n\n        if self.random_sample:\n            frame_idx = 
self._random_sample_frame_idx(len(frames))\n            frames = [frames[x].to_rgb().to_ndarray() for x in frame_idx]\n            frames = torch.as_tensor(np.stack(frames)).float() / 255.\n\n            if self.auto_augment is not None:\n                aug_transform = create_random_augment(\n                    input_size=(frames.size(1), frames.size(2)),\n                    auto_augment=self.auto_augment,\n                    interpolation=self.interpolation,\n                )\n                frames = frames.permute(0, 3, 1, 2) # T, C, H, W\n                frames = [transforms.ToPILImage()(frames[i]) for i in range(frames.size(0))]\n                frames = aug_transform(frames)\n                frames = torch.stack([transforms.ToTensor()(img) for img in frames])\n                frames = frames.permute(0, 2, 3, 1)\n\n            frames = (frames - self.mean) / self.std\n            frames = frames.permute(3, 0, 1, 2) # C, T, H, W\n            frames = random_resized_crop(\n                frames, self.spatial_size, self.spatial_size,\n            )\n            \n        else:\n            frames = [x.to_rgb().to_ndarray() for x in frames]\n            frames = torch.as_tensor(np.stack(frames))\n            frames = frames.float() / 255.\n\n            frames = (frames - self.mean) / self.std\n            frames = frames.permute(3, 0, 1, 2) # C, T, H, W\n            \n            if frames.size(-2) < frames.size(-1):\n                new_width = frames.size(-1) * self.spatial_size // frames.size(-2)\n                new_height = self.spatial_size\n            else:\n                new_height = frames.size(-2) * self.spatial_size // frames.size(-1)\n                new_width = self.spatial_size\n            frames = torch.nn.functional.interpolate(\n                frames, size=(new_height, new_width),\n                mode='bilinear', align_corners=False,\n            )\n\n            frames = self._generate_spatial_crops(frames)\n            frames = 
sum([self._generate_temporal_crops(x) for x in frames], [])\n            if len(frames) > 1:\n                frames = torch.stack(frames)\n\n        return frames, label\n\n\n    def _generate_temporal_crops(self, frames):\n        seg_len = (self.num_frames - 1) * self.sampling_rate + 1\n        if frames.size(1) < seg_len:\n            frames = torch.cat([frames, frames[:, -1:].repeat(1, seg_len - frames.size(1), 1, 1)], dim=1)\n        slide_len = frames.size(1) - seg_len\n\n        crops = []\n        for i in range(self.num_temporal_views):\n            if self.num_temporal_views == 1:\n                st = slide_len // 2\n            else:\n                st = round(slide_len / (self.num_temporal_views - 1) * i)\n\n            crops.append(frames[:, st: st + self.num_frames * self.sampling_rate: self.sampling_rate])\n        \n        return crops\n\n\n    def _generate_spatial_crops(self, frames):\n        if self.num_spatial_views == 1:\n            assert min(frames.size(-2), frames.size(-1)) >= self.spatial_size\n            h_st = (frames.size(-2) - self.spatial_size) // 2\n            w_st = (frames.size(-1) - self.spatial_size) // 2\n            h_ed, w_ed = h_st + self.spatial_size, w_st + self.spatial_size\n            return [frames[:, :, h_st: h_ed, w_st: w_ed]]\n\n        elif self.num_spatial_views == 3:\n            assert min(frames.size(-2), frames.size(-1)) == self.spatial_size\n            crops = []\n            margin = max(frames.size(-2), frames.size(-1)) - self.spatial_size\n            for st in (0, margin // 2, margin):\n                ed = st + self.spatial_size\n                if frames.size(-2) > frames.size(-1):\n                    crops.append(frames[:, :, st: ed, :])\n                else:\n                    crops.append(frames[:, :, :, st: ed])\n            return crops\n        \n        else:\n            raise NotImplementedError()\n\n\n    def _random_sample_frame_idx(self, len):\n        frame_indices = []\n\n       
 if self.sampling_rate < 0: # tsn sample\n            seg_size = (len - 1) / self.num_frames\n            for i in range(self.num_frames):\n                start, end = round(seg_size * i), round(seg_size * (i + 1))\n                frame_indices.append(np.random.randint(start, end + 1))\n        elif self.sampling_rate * (self.num_frames - 1) + 1 >= len:\n            for i in range(self.num_frames):\n                frame_indices.append(i * self.sampling_rate if i * self.sampling_rate < len else frame_indices[-1])\n        else:\n            start = np.random.randint(len - self.sampling_rate * (self.num_frames - 1))\n            frame_indices = list(range(start, start + self.sampling_rate * self.num_frames, self.sampling_rate))\n\n        return frame_indices\n\n\nclass DummyDataset(torch.utils.data.Dataset):\n\n    def __init__(self, list_path: str, num_frames: int, num_views: int, spatial_size: int):\n        with open(list_path) as f:\n            self.len = len(f.read().splitlines())\n        self.num_frames = num_frames\n        self.num_views = num_views\n        self.spatial_size = spatial_size\n\n    def __len__(self):\n        return self.len\n\n    def __getitem__(self, _):\n        shape = [3, self.num_frames, self.spatial_size, self.spatial_size]\n        if self.num_views != 1:\n            shape = [self.num_views] + shape\n        return torch.zeros(shape), 0\n"
  },
  {
    "path": "video_dataset/rand_augment.py",
    "content": "#!/usr/bin/env python\n# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/rand_augment.py\n\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\n\"\"\"\nThis implementation is based on\nhttps://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py\npulished under an Apache License 2.0.\n\nCOMMENT FROM ORIGINAL:\nAutoAugment, RandAugment, and AugMix for PyTorch\nThis code implements the searched ImageNet policies with various tweaks and\nimprovements and does not include any of the search code. AA and RA\nImplementation adapted from:\n    https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py\nAugMix adapted from:\n    https://github.com/google-research/augmix\nPapers:\n    AutoAugment: Learning Augmentation Policies from Data\n    https://arxiv.org/abs/1805.09501\n    Learning Data Augmentation Strategies for Object Detection\n    https://arxiv.org/abs/1906.11172\n    RandAugment: Practical automated data augmentation...\n    https://arxiv.org/abs/1909.13719\n    AugMix: A Simple Data Processing Method to Improve Robustness and\n    Uncertainty https://arxiv.org/abs/1912.02781\n\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\n\nimport math\nimport numpy as np\nimport random\nimport re\nimport PIL\nfrom PIL import Image, ImageEnhance, ImageOps\n\n_PIL_VER = tuple([int(x) for x in PIL.__version__.split(\".\")[:2]])\n\n_FILL = (128, 128, 128)\n\n# This signifies the max integer that the controller RNN could predict for the\n# augmentation scheme.\n_MAX_LEVEL = 10.0\n\n_HPARAMS_DEFAULT = {\n    \"translate_const\": 250,\n    \"img_mean\": _FILL,\n}\n\n_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)\n\n\ndef _interpolation(kwargs):\n    interpolation = kwargs.pop(\"resample\", Image.BILINEAR)\n    if isinstance(interpolation, (list, tuple)):\n        return 
random.choice(interpolation)\n    else:\n        return interpolation\n\n\ndef _check_args_tf(kwargs):\n    if \"fillcolor\" in kwargs and _PIL_VER < (5, 0):\n        kwargs.pop(\"fillcolor\")\n    kwargs[\"resample\"] = _interpolation(kwargs)\n\n\ndef shear_x(img, factor, **kwargs):\n    _check_args_tf(kwargs)\n    return img.transform(\n        img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs\n    )\n\n\ndef shear_y(img, factor, **kwargs):\n    _check_args_tf(kwargs)\n    return img.transform(\n        img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs\n    )\n\n\ndef translate_x_rel(img, pct, **kwargs):\n    pixels = pct * img.size[0]\n    _check_args_tf(kwargs)\n    return img.transform(\n        img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs\n    )\n\n\ndef translate_y_rel(img, pct, **kwargs):\n    pixels = pct * img.size[1]\n    _check_args_tf(kwargs)\n    return img.transform(\n        img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs\n    )\n\n\ndef translate_x_abs(img, pixels, **kwargs):\n    _check_args_tf(kwargs)\n    return img.transform(\n        img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs\n    )\n\n\ndef translate_y_abs(img, pixels, **kwargs):\n    _check_args_tf(kwargs)\n    return img.transform(\n        img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs\n    )\n\n\ndef rotate(img, degrees, **kwargs):\n    _check_args_tf(kwargs)\n    if _PIL_VER >= (5, 2):\n        return img.rotate(degrees, **kwargs)\n    elif _PIL_VER >= (5, 0):\n        w, h = img.size\n        post_trans = (0, 0)\n        rotn_center = (w / 2.0, h / 2.0)\n        angle = -math.radians(degrees)\n        matrix = [\n            round(math.cos(angle), 15),\n            round(math.sin(angle), 15),\n            0.0,\n            round(-math.sin(angle), 15),\n            round(math.cos(angle), 15),\n            0.0,\n        ]\n\n        def transform(x, y, matrix):\n            (a, b, c, d, e, f) = matrix\n            return a * 
x + b * y + c, d * x + e * y + f\n\n        matrix[2], matrix[5] = transform(\n            -rotn_center[0] - post_trans[0],\n            -rotn_center[1] - post_trans[1],\n            matrix,\n        )\n        matrix[2] += rotn_center[0]\n        matrix[5] += rotn_center[1]\n        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)\n    else:\n        return img.rotate(degrees, resample=kwargs[\"resample\"])\n\n\ndef auto_contrast(img, **__):\n    return ImageOps.autocontrast(img)\n\n\ndef invert(img, **__):\n    return ImageOps.invert(img)\n\n\ndef equalize(img, **__):\n    return ImageOps.equalize(img)\n\n\ndef solarize(img, thresh, **__):\n    return ImageOps.solarize(img, thresh)\n\n\ndef solarize_add(img, add, thresh=128, **__):\n    lut = []\n    for i in range(256):\n        if i < thresh:\n            lut.append(min(255, i + add))\n        else:\n            lut.append(i)\n    if img.mode in (\"L\", \"RGB\"):\n        if img.mode == \"RGB\" and len(lut) == 256:\n            lut = lut + lut + lut\n        return img.point(lut)\n    else:\n        return img\n\n\ndef posterize(img, bits_to_keep, **__):\n    if bits_to_keep >= 8:\n        return img\n    return ImageOps.posterize(img, bits_to_keep)\n\n\ndef contrast(img, factor, **__):\n    return ImageEnhance.Contrast(img).enhance(factor)\n\n\ndef color(img, factor, **__):\n    return ImageEnhance.Color(img).enhance(factor)\n\n\ndef brightness(img, factor, **__):\n    return ImageEnhance.Brightness(img).enhance(factor)\n\n\ndef sharpness(img, factor, **__):\n    return ImageEnhance.Sharpness(img).enhance(factor)\n\n\ndef _randomly_negate(v):\n    \"\"\"With 50% prob, negate the value\"\"\"\n    return -v if random.random() > 0.5 else v\n\n\ndef _rotate_level_to_arg(level, _hparams):\n    # range [-30, 30]\n    level = (level / _MAX_LEVEL) * 30.0\n    level = _randomly_negate(level)\n    return (level,)\n\n\ndef _enhance_level_to_arg(level, _hparams):\n    # range [0.1, 1.9]\n    return ((level / 
_MAX_LEVEL) * 1.8 + 0.1,)\n\n\ndef _enhance_increasing_level_to_arg(level, _hparams):\n    # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend\n    # range [0.1, 1.9]\n    level = (level / _MAX_LEVEL) * 0.9\n    level = 1.0 + _randomly_negate(level)\n    return (level,)\n\n\ndef _shear_level_to_arg(level, _hparams):\n    # range [-0.3, 0.3]\n    level = (level / _MAX_LEVEL) * 0.3\n    level = _randomly_negate(level)\n    return (level,)\n\n\ndef _translate_abs_level_to_arg(level, hparams):\n    translate_const = hparams[\"translate_const\"]\n    level = (level / _MAX_LEVEL) * float(translate_const)\n    level = _randomly_negate(level)\n    return (level,)\n\n\ndef _translate_rel_level_to_arg(level, hparams):\n    # default range [-0.45, 0.45]\n    translate_pct = hparams.get(\"translate_pct\", 0.45)\n    level = (level / _MAX_LEVEL) * translate_pct\n    level = _randomly_negate(level)\n    return (level,)\n\n\ndef _posterize_level_to_arg(level, _hparams):\n    # As per Tensorflow TPU EfficientNet impl\n    # range [0, 4], 'keep 0 up to 4 MSB of original image'\n    # intensity/severity of augmentation decreases with level\n    return (int((level / _MAX_LEVEL) * 4),)\n\n\ndef _posterize_increasing_level_to_arg(level, hparams):\n    # As per Tensorflow models research and UDA impl\n    # range [4, 0], 'keep 4 down to 0 MSB of original image',\n    # intensity/severity of augmentation increases with level\n    return (4 - _posterize_level_to_arg(level, hparams)[0],)\n\n\ndef _posterize_original_level_to_arg(level, _hparams):\n    # As per original AutoAugment paper description\n    # range [4, 8], 'keep 4 up to 8 MSB of image'\n    # intensity/severity of augmentation decreases with level\n    return (int((level / _MAX_LEVEL) * 4) + 4,)\n\n\ndef _solarize_level_to_arg(level, _hparams):\n    # range [0, 256]\n    # intensity/severity of augmentation decreases with level\n    return (int((level / _MAX_LEVEL) * 
256),)\n\n\ndef _solarize_increasing_level_to_arg(level, _hparams):\n    # range [0, 256]\n    # intensity/severity of augmentation increases with level\n    return (256 - _solarize_level_to_arg(level, _hparams)[0],)\n\n\ndef _solarize_add_level_to_arg(level, _hparams):\n    # range [0, 110]\n    return (int((level / _MAX_LEVEL) * 110),)\n\n\nLEVEL_TO_ARG = {\n    \"AutoContrast\": None,\n    \"Equalize\": None,\n    \"Invert\": None,\n    \"Rotate\": _rotate_level_to_arg,\n    # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers\n    \"Posterize\": _posterize_level_to_arg,\n    \"PosterizeIncreasing\": _posterize_increasing_level_to_arg,\n    \"PosterizeOriginal\": _posterize_original_level_to_arg,\n    \"Solarize\": _solarize_level_to_arg,\n    \"SolarizeIncreasing\": _solarize_increasing_level_to_arg,\n    \"SolarizeAdd\": _solarize_add_level_to_arg,\n    \"Color\": _enhance_level_to_arg,\n    \"ColorIncreasing\": _enhance_increasing_level_to_arg,\n    \"Contrast\": _enhance_level_to_arg,\n    \"ContrastIncreasing\": _enhance_increasing_level_to_arg,\n    \"Brightness\": _enhance_level_to_arg,\n    \"BrightnessIncreasing\": _enhance_increasing_level_to_arg,\n    \"Sharpness\": _enhance_level_to_arg,\n    \"SharpnessIncreasing\": _enhance_increasing_level_to_arg,\n    \"ShearX\": _shear_level_to_arg,\n    \"ShearY\": _shear_level_to_arg,\n    \"TranslateX\": _translate_abs_level_to_arg,\n    \"TranslateY\": _translate_abs_level_to_arg,\n    \"TranslateXRel\": _translate_rel_level_to_arg,\n    \"TranslateYRel\": _translate_rel_level_to_arg,\n}\n\n\nNAME_TO_OP = {\n    \"AutoContrast\": auto_contrast,\n    \"Equalize\": equalize,\n    \"Invert\": invert,\n    \"Rotate\": rotate,\n    \"Posterize\": posterize,\n    \"PosterizeIncreasing\": posterize,\n    \"PosterizeOriginal\": posterize,\n    \"Solarize\": solarize,\n    \"SolarizeIncreasing\": solarize,\n    \"SolarizeAdd\": solarize_add,\n    \"Color\": 
color,\n    \"ColorIncreasing\": color,\n    \"Contrast\": contrast,\n    \"ContrastIncreasing\": contrast,\n    \"Brightness\": brightness,\n    \"BrightnessIncreasing\": brightness,\n    \"Sharpness\": sharpness,\n    \"SharpnessIncreasing\": sharpness,\n    \"ShearX\": shear_x,\n    \"ShearY\": shear_y,\n    \"TranslateX\": translate_x_abs,\n    \"TranslateY\": translate_y_abs,\n    \"TranslateXRel\": translate_x_rel,\n    \"TranslateYRel\": translate_y_rel,\n}\n\n\nclass AugmentOp:\n    \"\"\"\n    Apply for video.\n    \"\"\"\n\n    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):\n        hparams = hparams or _HPARAMS_DEFAULT\n        self.aug_fn = NAME_TO_OP[name]\n        self.level_fn = LEVEL_TO_ARG[name]\n        self.prob = prob\n        self.magnitude = magnitude\n        self.hparams = hparams.copy()\n        self.kwargs = {\n            \"fillcolor\": hparams[\"img_mean\"]\n            if \"img_mean\" in hparams\n            else _FILL,\n            \"resample\": hparams[\"interpolation\"]\n            if \"interpolation\" in hparams\n            else _RANDOM_INTERPOLATION,\n        }\n\n        # If magnitude_std is > 0, we introduce some randomness\n        # in the usually fixed policy and sample magnitude from a normal distribution\n        # with mean `magnitude` and std-dev of `magnitude_std`.\n        # NOTE This is my own hack, being tested, not in papers or reference impls.\n        self.magnitude_std = self.hparams.get(\"magnitude_std\", 0)\n\n    def __call__(self, img_list):\n        if self.prob < 1.0 and random.random() > self.prob:\n            return img_list\n        magnitude = self.magnitude\n        if self.magnitude_std and self.magnitude_std > 0:\n            magnitude = random.gauss(magnitude, self.magnitude_std)\n        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range\n        level_args = (\n            self.level_fn(magnitude, self.hparams)\n            if self.level_fn is not None\n       
     else ()\n        )\n\n        if isinstance(img_list, list):\n            return [\n                self.aug_fn(img, *level_args, **self.kwargs) for img in img_list\n            ]\n        else:\n            return self.aug_fn(img_list, *level_args, **self.kwargs)\n\n\n_RAND_TRANSFORMS = [\n    \"AutoContrast\",\n    \"Equalize\",\n    \"Invert\",\n    \"Rotate\",\n    \"Posterize\",\n    \"Solarize\",\n    \"SolarizeAdd\",\n    \"Color\",\n    \"Contrast\",\n    \"Brightness\",\n    \"Sharpness\",\n    \"ShearX\",\n    \"ShearY\",\n    \"TranslateXRel\",\n    \"TranslateYRel\",\n]\n\n\n_RAND_INCREASING_TRANSFORMS = [\n    \"AutoContrast\",\n    \"Equalize\",\n    \"Invert\",\n    \"Rotate\",\n    \"PosterizeIncreasing\",\n    \"SolarizeIncreasing\",\n    \"SolarizeAdd\",\n    \"ColorIncreasing\",\n    \"ContrastIncreasing\",\n    \"BrightnessIncreasing\",\n    \"SharpnessIncreasing\",\n    \"ShearX\",\n    \"ShearY\",\n    \"TranslateXRel\",\n    \"TranslateYRel\",\n]\n\n\n# These experimental weights are based loosely on the relative improvements mentioned in paper.\n# They may not result in increased performance, but could likely be tuned to so.\n_RAND_CHOICE_WEIGHTS_0 = {\n    \"Rotate\": 0.3,\n    \"ShearX\": 0.2,\n    \"ShearY\": 0.2,\n    \"TranslateXRel\": 0.1,\n    \"TranslateYRel\": 0.1,\n    \"Color\": 0.025,\n    \"Sharpness\": 0.025,\n    \"AutoContrast\": 0.025,\n    \"Solarize\": 0.005,\n    \"SolarizeAdd\": 0.005,\n    \"Contrast\": 0.005,\n    \"Brightness\": 0.005,\n    \"Equalize\": 0.005,\n    \"Posterize\": 0,\n    \"Invert\": 0,\n}\n\n\ndef _select_rand_weights(weight_idx=0, transforms=None):\n    transforms = transforms or _RAND_TRANSFORMS\n    assert weight_idx == 0  # only one set of weights currently\n    rand_weights = _RAND_CHOICE_WEIGHTS_0\n    probs = [rand_weights[k] for k in transforms]\n    probs /= np.sum(probs)\n    return probs\n\n\ndef rand_augment_ops(magnitude=10, hparams=None, transforms=None):\n    hparams = hparams or 
_HPARAMS_DEFAULT\n    transforms = transforms or _RAND_TRANSFORMS\n    return [\n        AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)\n        for name in transforms\n    ]\n\n\nclass RandAugment:\n    def __init__(self, ops, num_layers=2, choice_weights=None):\n        self.ops = ops\n        self.num_layers = num_layers\n        self.choice_weights = choice_weights\n\n    def __call__(self, img):\n        # no replacement when using weighted choice\n        ops = np.random.choice(\n            self.ops,\n            self.num_layers,\n            replace=self.choice_weights is None,\n            p=self.choice_weights,\n        )\n        for op in ops:\n            img = op(img)\n        return img\n\n\ndef rand_augment_transform(config_str, hparams):\n    \"\"\"\n    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719\n\n    Create a RandAugment transform\n    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by\n    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining\n    sections, in no particular order, determine\n        'm' - integer magnitude of rand augment\n        'n' - integer num layers (number of transform ops selected per image)\n        'w' - integer probability weight index (index of a set of weights to influence choice of op)\n        'mstd' -  float std deviation of magnitude noise applied\n        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)\n    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5\n    'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2\n    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme\n    :return: A PyTorch compatible Transform\n    \"\"\"\n    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)\n    num_layers = 2  # default to 2 ops per image\n    weight_idx = None  # default to no probability weights for op choice\n    transforms = _RAND_TRANSFORMS\n    config = config_str.split(\"-\")\n    assert config[0] == \"rand\"\n    config = config[1:]\n    for c in config:\n        cs = re.split(r\"(\\d.*)\", c)\n        if len(cs) < 2:\n            continue\n        key, val = cs[:2]\n        if key == \"mstd\":\n            # noise param injected via hparams for now\n            hparams.setdefault(\"magnitude_std\", float(val))\n        elif key == \"inc\":\n            # val is a string, so bool(val) would be True even for \"0\"\n            if int(val):\n                transforms = _RAND_INCREASING_TRANSFORMS\n        elif key == \"m\":\n            magnitude = int(val)\n        elif key == \"n\":\n            num_layers = int(val)\n        elif key == \"w\":\n            weight_idx = int(val)\n        else:\n            # `assert NotImplementedError` always passes; raise instead\n            raise NotImplementedError(\n                \"Unknown RandAugment config section: \" + c\n            )\n    ra_ops = rand_augment_ops(\n        magnitude=magnitude, hparams=hparams, transforms=transforms\n    )\n    choice_weights = (\n        None if weight_idx is None else _select_rand_weights(weight_idx)\n    )\n    
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)\n"
  },
  {
    "path": "video_dataset/random_erasing.py",
    "content": "#!/usr/bin/env python\n# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/random_erasing.py\n\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\n\"\"\"\nThis implementation is based on\nhttps://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py\npulished under an Apache License 2.0.\n\nCOMMENT FROM ORIGINAL:\nOriginally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0\nCopyright Zhun Zhong & Liang Zheng\nHacked together by / Copyright 2020 Ross Wightman\n\"\"\"\nimport math\nimport random\nimport torch\n\n\ndef _get_pixels(\n    per_pixel, rand_color, patch_size, dtype=torch.float32, device=\"cuda\"\n):\n    # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()\n    # paths, flip the order so normal is run on CPU if this becomes a problem\n    # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508\n    if per_pixel:\n        return torch.empty(patch_size, dtype=dtype, device=device).normal_()\n    elif rand_color:\n        return torch.empty(\n            (patch_size[0], 1, 1), dtype=dtype, device=device\n        ).normal_()\n    else:\n        return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)\n\n\nclass RandomErasing:\n    \"\"\"Randomly selects a rectangle region in an image and erases its pixels.\n        'Random Erasing Data Augmentation' by Zhong et al.\n        See https://arxiv.org/pdf/1708.04896.pdf\n        This variant of RandomErasing is intended to be applied to either a batch\n        or single image tensor after it has been normalized by dataset mean and std.\n    Args:\n         probability: Probability that the Random Erasing operation will be performed.\n         min_area: Minimum percentage of erased area wrt input image area.\n         max_area: Maximum percentage of erased area wrt input image 
area.\n         min_aspect: Minimum aspect ratio of erased area.\n         mode: pixel color mode, one of 'const', 'rand', or 'pixel'\n            'const' - erase block is constant color of 0 for all channels\n            'rand'  - erase block is same per-channel random (normal) color\n            'pixel' - erase block is per-pixel random (normal) color\n        max_count: maximum number of erasing blocks per image, area per box is scaled by count.\n            per-image count is randomly chosen between 1 and this value.\n    \"\"\"\n\n    def __init__(\n        self,\n        probability=0.5,\n        min_area=0.02,\n        max_area=1 / 3,\n        min_aspect=0.3,\n        max_aspect=None,\n        mode=\"const\",\n        min_count=1,\n        max_count=None,\n        num_splits=0,\n        device=\"cuda\",\n        cube=True,\n    ):\n        self.probability = probability\n        self.min_area = min_area\n        self.max_area = max_area\n        max_aspect = max_aspect or 1 / min_aspect\n        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))\n        self.min_count = min_count\n        self.max_count = max_count or min_count\n        self.num_splits = num_splits\n        mode = mode.lower()\n        self.rand_color = False\n        self.per_pixel = False\n        self.cube = cube\n        if mode == \"rand\":\n            self.rand_color = True  # per block random normal\n        elif mode == \"pixel\":\n            self.per_pixel = True  # per pixel random normal\n        else:\n            assert not mode or mode == \"const\"\n        self.device = device\n\n    def _erase(self, img, chan, img_h, img_w, dtype):\n        if random.random() > self.probability:\n            return\n        area = img_h * img_w\n        count = (\n            self.min_count\n            if self.min_count == self.max_count\n            else random.randint(self.min_count, self.max_count)\n        )\n        for _ in range(count):\n            for _ in 
range(10):\n                target_area = (\n                    random.uniform(self.min_area, self.max_area) * area / count\n                )\n                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))\n                h = int(round(math.sqrt(target_area * aspect_ratio)))\n                w = int(round(math.sqrt(target_area / aspect_ratio)))\n                if w < img_w and h < img_h:\n                    top = random.randint(0, img_h - h)\n                    left = random.randint(0, img_w - w)\n                    img[:, top : top + h, left : left + w] = _get_pixels(\n                        self.per_pixel,\n                        self.rand_color,\n                        (chan, h, w),\n                        dtype=dtype,\n                        device=self.device,\n                    )\n                    break\n\n    def _erase_cube(\n        self,\n        img,\n        batch_start,\n        batch_size,\n        chan,\n        img_h,\n        img_w,\n        dtype,\n    ):\n        if random.random() > self.probability:\n            return\n        area = img_h * img_w\n        count = (\n            self.min_count\n            if self.min_count == self.max_count\n            else random.randint(self.min_count, self.max_count)\n        )\n        for _ in range(count):\n            for _ in range(100):\n                target_area = (\n                    random.uniform(self.min_area, self.max_area) * area / count\n                )\n                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))\n                h = int(round(math.sqrt(target_area * aspect_ratio)))\n                w = int(round(math.sqrt(target_area / aspect_ratio)))\n                if w < img_w and h < img_h:\n                    top = random.randint(0, img_h - h)\n                    left = random.randint(0, img_w - w)\n                    for i in range(batch_start, batch_size):\n                        img_instance = img[i]\n                  
      img_instance[\n                            :, top : top + h, left : left + w\n                        ] = _get_pixels(\n                            self.per_pixel,\n                            self.rand_color,\n                            (chan, h, w),\n                            dtype=dtype,\n                            device=self.device,\n                        )\n                    break\n\n    def __call__(self, input):\n        if len(input.size()) == 3:\n            self._erase(input, *input.size(), input.dtype)\n        else:\n            batch_size, chan, img_h, img_w = input.size()\n            # skip first slice of batch if num_splits is set (for clean portion of samples)\n            batch_start = (\n                batch_size // self.num_splits if self.num_splits > 1 else 0\n            )\n            if self.cube:\n                self._erase_cube(\n                    input,\n                    batch_start,\n                    batch_size,\n                    chan,\n                    img_h,\n                    img_w,\n                    input.dtype,\n                )\n            else:\n                for i in range(batch_start, batch_size):\n                    self._erase(input[i], chan, img_h, img_w, input.dtype)\n        return input\n"
  },
  {
    "path": "video_dataset/transform.py",
    "content": "#!/usr/bin/env python3\n# Originate from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/transform.py\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\nimport logging\nimport math\nimport numpy as np\n\n# import cv2\nimport random\nimport torch\nimport torchvision as tv\nimport torchvision.transforms.functional as F\nfrom PIL import Image, ImageFilter\nfrom torchvision import transforms\n\nfrom .rand_augment import rand_augment_transform\nfrom .random_erasing import RandomErasing\n\n_pil_interpolation_to_str = {\n    Image.NEAREST: \"PIL.Image.NEAREST\",\n    Image.BILINEAR: \"PIL.Image.BILINEAR\",\n    Image.BICUBIC: \"PIL.Image.BICUBIC\",\n    Image.LANCZOS: \"PIL.Image.LANCZOS\",\n    Image.HAMMING: \"PIL.Image.HAMMING\",\n    Image.BOX: \"PIL.Image.BOX\",\n}\n\n\n_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)\n\n\ndef _pil_interp(method):\n    if method == \"bicubic\":\n        return Image.BICUBIC\n    elif method == \"lanczos\":\n        return Image.LANCZOS\n    elif method == \"hamming\":\n        return Image.HAMMING\n    else:\n        return Image.BILINEAR\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef random_short_side_scale_jitter(\n    images, min_size, max_size, boxes=None, inverse_uniform_sampling=False\n):\n    \"\"\"\n    Perform a spatial short scale jittering on the given images and\n    corresponding boxes.\n    Args:\n        images (tensor): images to perform scale jitter. Dimension is\n            `num frames` x `channel` x `height` x `width`.\n        min_size (int): the minimal size to scale the frames.\n        max_size (int): the maximal size to scale the frames.\n        boxes (ndarray): optional. 
Corresponding boxes to images.\n            Dimension is `num boxes` x 4.\n        inverse_uniform_sampling (bool): if True, sample uniformly in\n            [1 / max_size, 1 / min_size] and take a reciprocal to get the\n            size. If False, take a uniform sample from [min_size, max_size].\n    Returns:\n        (tensor): the scaled images with dimension of\n            `num frames` x `channel` x `new height` x `new width`.\n        (ndarray or None): the scaled boxes with dimension of\n            `num boxes` x 4.\n    \"\"\"\n    if inverse_uniform_sampling:\n        size = int(\n            round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))\n        )\n    else:\n        size = int(round(np.random.uniform(min_size, max_size)))\n\n    height = images.shape[2]\n    width = images.shape[3]\n    if (width <= height and width == size) or (\n        height <= width and height == size\n    ):\n        return images, boxes\n    new_width = size\n    new_height = size\n    if width < height:\n        new_height = int(math.floor((float(height) / width) * size))\n        if boxes is not None:\n            boxes = boxes * float(new_height) / height\n    else:\n        new_width = int(math.floor((float(width) / height) * size))\n        if boxes is not None:\n            boxes = boxes * float(new_width) / width\n\n    return (\n        torch.nn.functional.interpolate(\n            images,\n            size=(new_height, new_width),\n            mode=\"bilinear\",\n            align_corners=False,\n        ),\n        boxes,\n    )\n\n\ndef crop_boxes(boxes, x_offset, y_offset):\n    \"\"\"\n    Perform crop on the bounding boxes given the offsets.\n    Args:\n        boxes (ndarray or None): bounding boxes to perform crop. 
The dimension\n            is `num boxes` x 4.\n        x_offset (int): cropping offset in the x axis.\n        y_offset (int): cropping offset in the y axis.\n    Returns:\n        cropped_boxes (ndarray or None): the cropped boxes with dimension of\n            `num boxes` x 4.\n    \"\"\"\n    cropped_boxes = boxes.copy()\n    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset\n    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset\n\n    return cropped_boxes\n\n\ndef random_crop(images, size, boxes=None):\n    \"\"\"\n    Perform random spatial crop on the given images and corresponding boxes.\n    Args:\n        images (tensor): images to perform random crop. The dimension is\n            `num frames` x `channel` x `height` x `width`.\n        size (int): the size of height and width to crop on the image.\n        boxes (ndarray or None): optional. Corresponding boxes to images.\n            Dimension is `num boxes` x 4.\n    Returns:\n        cropped (tensor): cropped images with dimension of\n            `num frames` x `channel` x `size` x `size`.\n        cropped_boxes (ndarray or None): the cropped boxes with dimension of\n            `num boxes` x 4.\n    \"\"\"\n    if images.shape[2] == size and images.shape[3] == size:\n        return images, boxes\n    height = images.shape[2]\n    width = images.shape[3]\n    y_offset = 0\n    if height > size:\n        y_offset = int(np.random.randint(0, height - size))\n    x_offset = 0\n    if width > size:\n        x_offset = int(np.random.randint(0, width - size))\n    cropped = images[\n        :, :, y_offset : y_offset + size, x_offset : x_offset + size\n    ]\n\n    cropped_boxes = (\n        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None\n    )\n\n    return cropped, cropped_boxes\n\n\ndef horizontal_flip(prob, images, boxes=None):\n    \"\"\"\n    Perform horizontal flip on the given images and corresponding boxes.\n    Args:\n        prob (float): probability to flip the 
images.\n        images (tensor): images to perform horizontal flip, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n        boxes (ndarray or None): optional. Corresponding boxes to images.\n            Dimension is `num boxes` x 4.\n    Returns:\n        images (tensor): images with dimension of\n            `num frames` x `channel` x `height` x `width`.\n        flipped_boxes (ndarray or None): the flipped boxes with dimension of\n            `num boxes` x 4.\n    \"\"\"\n    if boxes is None:\n        flipped_boxes = None\n    else:\n        flipped_boxes = boxes.copy()\n\n    if np.random.uniform() < prob:\n        images = images.flip((-1))\n\n        if len(images.shape) == 3:\n            width = images.shape[2]\n        elif len(images.shape) == 4:\n            width = images.shape[3]\n        else:\n            raise NotImplementedError(\"Dimension is not supported\")\n        if boxes is not None:\n            flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1\n\n    return images, flipped_boxes\n\n\ndef uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):\n    \"\"\"\n    Perform uniform spatial sampling on the images and corresponding boxes.\n    Args:\n        images (tensor): images to perform uniform crop. The dimension is\n            `num frames` x `channel` x `height` x `width`.\n        size (int): size of height and width to crop the images.\n        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width\n            is larger than height. Or 0, 1, or 2 for top, center, and bottom\n            crop if height is larger than width.\n        boxes (ndarray or None): optional. Corresponding boxes to images.\n            Dimension is `num boxes` x 4.\n        scale_size (int): optional. 
If not None, resize the images to scale_size before\n            performing any crop.\n    Returns:\n        cropped (tensor): images with dimension of\n            `num frames` x `channel` x `size` x `size`.\n        cropped_boxes (ndarray or None): the cropped boxes with dimension of\n            `num boxes` x 4.\n    \"\"\"\n    assert spatial_idx in [0, 1, 2]\n    ndim = len(images.shape)\n    if ndim == 3:\n        images = images.unsqueeze(0)\n    height = images.shape[2]\n    width = images.shape[3]\n\n    if scale_size is not None:\n        if width <= height:\n            width, height = scale_size, int(height / width * scale_size)\n        else:\n            width, height = int(width / height * scale_size), scale_size\n        images = torch.nn.functional.interpolate(\n            images,\n            size=(height, width),\n            mode=\"bilinear\",\n            align_corners=False,\n        )\n\n    y_offset = int(math.ceil((height - size) / 2))\n    x_offset = int(math.ceil((width - size) / 2))\n\n    if height > width:\n        if spatial_idx == 0:\n            y_offset = 0\n        elif spatial_idx == 2:\n            y_offset = height - size\n    else:\n        if spatial_idx == 0:\n            x_offset = 0\n        elif spatial_idx == 2:\n            x_offset = width - size\n    cropped = images[\n        :, :, y_offset : y_offset + size, x_offset : x_offset + size\n    ]\n    cropped_boxes = (\n        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None\n    )\n    if ndim == 3:\n        cropped = cropped.squeeze(0)\n    return cropped, cropped_boxes\n\n\ndef clip_boxes_to_image(boxes, height, width):\n    \"\"\"\n    Clip an array of boxes to an image with the given height and width.\n    Args:\n        boxes (ndarray): bounding boxes to perform clipping.\n            Dimension is `num boxes` x 4.\n        height (int): given image height.\n        width (int): given image width.\n    Returns:\n        clipped_boxes (ndarray): 
the clipped boxes with dimension of\n            `num boxes` x 4.\n    \"\"\"\n    clipped_boxes = boxes.copy()\n    clipped_boxes[:, [0, 2]] = np.minimum(\n        width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])\n    )\n    clipped_boxes[:, [1, 3]] = np.minimum(\n        height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])\n    )\n    return clipped_boxes\n\n\ndef blend(images1, images2, alpha):\n    \"\"\"\n    Blend two images with a given weight alpha.\n    Args:\n        images1 (tensor): the first images to be blended, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n        images2 (tensor): the second images to be blended, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n        alpha (float): the blending weight.\n    Returns:\n        (tensor): blended images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n    return images1 * alpha + images2 * (1 - alpha)\n\n\ndef grayscale(images):\n    \"\"\"\n    Get the grayscale for the input images. The channels of images should be\n    in order BGR.\n    Args:\n        images (tensor): the input images for getting grayscale. Dimension is\n            `num frames` x `channel` x `height` x `width`.\n    Returns:\n        img_gray (tensor): grayscale images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n    # R -> 0.299, G -> 0.587, B -> 0.114.\n    # Copy the input; clone().detach() avoids the warning that\n    # torch.tensor() emits when given an existing tensor.\n    img_gray = images.clone().detach()\n    gray_channel = (\n        0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]\n    )\n    img_gray[:, 0] = gray_channel\n    img_gray[:, 1] = gray_channel\n    img_gray[:, 2] = gray_channel\n    return img_gray\n\n\ndef color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):\n    \"\"\"\n    Perform color jittering on the input images. The channels of images\n    should be in order BGR.\n    Args:\n        images (tensor): images to perform color jitter. 
Dimension is\n            `num frames` x `channel` x `height` x `width`.\n        img_brightness (float): jitter ratio for brightness.\n        img_contrast (float): jitter ratio for contrast.\n        img_saturation (float): jitter ratio for saturation.\n    Returns:\n        images (tensor): the jittered images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n\n    jitter = []\n    if img_brightness != 0:\n        jitter.append(\"brightness\")\n    if img_contrast != 0:\n        jitter.append(\"contrast\")\n    if img_saturation != 0:\n        jitter.append(\"saturation\")\n\n    if len(jitter) > 0:\n        order = np.random.permutation(np.arange(len(jitter)))\n        for idx in range(0, len(jitter)):\n            if jitter[order[idx]] == \"brightness\":\n                images = brightness_jitter(img_brightness, images)\n            elif jitter[order[idx]] == \"contrast\":\n                images = contrast_jitter(img_contrast, images)\n            elif jitter[order[idx]] == \"saturation\":\n                images = saturation_jitter(img_saturation, images)\n    return images\n\n\ndef brightness_jitter(var, images):\n    \"\"\"\n    Perform brightness jittering on the input images. The channels of images\n    should be in order BGR.\n    Args:\n        var (float): jitter ratio for brightness.\n        images (tensor): images to perform color jitter. Dimension is\n            `num frames` x `channel` x `height` x `width`.\n    Returns:\n        images (tensor): the jittered images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n    alpha = 1.0 + np.random.uniform(-var, var)\n\n    img_bright = torch.zeros(images.shape)\n    images = blend(images, img_bright, alpha)\n    return images\n\n\ndef contrast_jitter(var, images):\n    \"\"\"\n    Perform contrast jittering on the input images. 
The channels of images\n    should be in order BGR.\n    Args:\n        var (float): jitter ratio for contrast.\n        images (tensor): images to perform color jitter. Dimension is\n            `num frames` x `channel` x `height` x `width`.\n    Returns:\n        images (tensor): the jittered images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n    alpha = 1.0 + np.random.uniform(-var, var)\n\n    img_gray = grayscale(images)\n    img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)\n    images = blend(images, img_gray, alpha)\n    return images\n\n\ndef saturation_jitter(var, images):\n    \"\"\"\n    Perform saturation jittering on the input images. The channels of images\n    should be in order BGR.\n    Args:\n        var (float): jitter ratio for saturation.\n        images (tensor): images to perform color jitter. Dimension is\n            `num frames` x `channel` x `height` x `width`.\n    Returns:\n        images (tensor): the jittered images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n    alpha = 1.0 + np.random.uniform(-var, var)\n    img_gray = grayscale(images)\n    images = blend(images, img_gray, alpha)\n\n    return images\n\n\ndef lighting_jitter(images, alphastd, eigval, eigvec):\n    \"\"\"\n    Perform AlexNet-style PCA jitter on the given images.\n    Args:\n        images (tensor): images to perform lighting jitter. 
Dimension is\n            `num frames` x `channel` x `height` x `width`.\n        alphastd (float): jitter ratio for PCA jitter.\n        eigval (list): eigenvalues for PCA jitter.\n        eigvec (list[list]): eigenvectors for PCA jitter.\n    Returns:\n        out_images (tensor): the jittered images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n    if alphastd == 0:\n        return images\n    # generate alpha1, alpha2, alpha3.\n    alpha = np.random.normal(0, alphastd, size=(1, 3))\n    eig_vec = np.array(eigvec)\n    eig_val = np.reshape(eigval, (1, 3))\n    rgb = np.sum(\n        eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),\n        axis=1,\n    )\n    out_images = torch.zeros_like(images)\n    if len(images.shape) == 3:\n        # C H W\n        channel_dim = 0\n    elif len(images.shape) == 4:\n        # T C H W\n        channel_dim = 1\n    else:\n        raise NotImplementedError(f\"Unsupported dimension {len(images.shape)}\")\n\n    for idx in range(images.shape[channel_dim]):\n        # C H W\n        if len(images.shape) == 3:\n            out_images[idx] = images[idx] + rgb[2 - idx]\n        # T C H W\n        elif len(images.shape) == 4:\n            out_images[:, idx] = images[:, idx] + rgb[2 - idx]\n        else:\n            raise NotImplementedError(\n                f\"Unsupported dimension {len(images.shape)}\"\n            )\n\n    return out_images\n\n\ndef color_normalization(images, mean, stddev):\n    \"\"\"\n    Perform color normalization on the given images.\n    Args:\n        images (tensor): images to perform color normalization. 
Dimension is\n            `num frames` x `channel` x `height` x `width`.\n        mean (list): mean values for normalization.\n        stddev (list): standard deviations for normalization.\n\n    Returns:\n        out_images (tensor): the normalized images, the dimension is\n            `num frames` x `channel` x `height` x `width`.\n    \"\"\"\n    if len(images.shape) == 3:\n        assert (\n            len(mean) == images.shape[0]\n        ), \"channel mean not computed properly\"\n        assert (\n            len(stddev) == images.shape[0]\n        ), \"channel stddev not computed properly\"\n    elif len(images.shape) == 4:\n        assert (\n            len(mean) == images.shape[1]\n        ), \"channel mean not computed properly\"\n        assert (\n            len(stddev) == images.shape[1]\n        ), \"channel stddev not computed properly\"\n    else:\n        raise NotImplementedError(f\"Unsupported dimension {len(images.shape)}\")\n\n    out_images = torch.zeros_like(images)\n    for idx in range(len(mean)):\n        # C H W\n        if len(images.shape) == 3:\n            out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]\n        elif len(images.shape) == 4:\n            out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]\n        else:\n            raise NotImplementedError(\n                f\"Unsupported dimension {len(images.shape)}\"\n            )\n    return out_images\n\n\ndef _get_param_spatial_crop(\n    scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False\n):\n    \"\"\"\n    Given scale, ratio, height and width, return sampled coordinates of the videos.\n    \"\"\"\n    for _ in range(num_repeat):\n        area = height * width\n        target_area = random.uniform(*scale) * area\n        if log_scale:\n            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))\n            aspect_ratio = math.exp(random.uniform(*log_ratio))\n        else:\n            aspect_ratio = 
random.uniform(*ratio)\n\n        w = int(round(math.sqrt(target_area * aspect_ratio)))\n        h = int(round(math.sqrt(target_area / aspect_ratio)))\n\n        if np.random.uniform() < 0.5 and switch_hw:\n            w, h = h, w\n\n        if 0 < w <= width and 0 < h <= height:\n            i = random.randint(0, height - h)\n            j = random.randint(0, width - w)\n            return i, j, h, w\n\n    # Fallback to central crop\n    in_ratio = float(width) / float(height)\n    if in_ratio < min(ratio):\n        w = width\n        h = int(round(w / min(ratio)))\n    elif in_ratio > max(ratio):\n        h = height\n        w = int(round(h * max(ratio)))\n    else:  # whole image\n        w = width\n        h = height\n    i = (height - h) // 2\n    j = (width - w) // 2\n    return i, j, h, w\n\n\ndef random_resized_crop(\n    images,\n    target_height,\n    target_width,\n    scale=(0.08, 1.0),\n    ratio=(3.0 / 4.0, 4.0 / 3.0),\n):\n    \"\"\"\n    Crop the given images to random size and aspect ratio. A crop of random\n    size (default: of 0.08 to 1.0) of the original size and a random aspect\n    ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This\n    crop is finally resized to given size. 
This is popularly used to train the\n    Inception networks.\n\n    Args:\n        images: Images to perform resizing and cropping.\n        target_height: Desired height after cropping.\n        target_width: Desired width after cropping.\n        scale: Scale range of Inception-style area based random resizing.\n        ratio: Aspect ratio range of Inception-style area based random resizing.\n    \"\"\"\n\n    height = images.shape[2]\n    width = images.shape[3]\n\n    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)\n    cropped = images[:, :, i : i + h, j : j + w]\n    return torch.nn.functional.interpolate(\n        cropped,\n        size=(target_height, target_width),\n        mode=\"bilinear\",\n        align_corners=False,\n    )\n\n\ndef random_resized_crop_with_shift(\n    images,\n    target_height,\n    target_width,\n    scale=(0.8, 1.0),\n    ratio=(3.0 / 4.0, 4.0 / 3.0),\n):\n    \"\"\"\n    This is similar to random_resized_crop. However, it samples two different\n    boxes (for cropping) for the first and last frame. 
It then linearly\n    interpolates the two boxes for other frames.\n\n    Args:\n        images: Images to perform resizing and cropping.\n        target_height: Desired height after cropping.\n        target_width: Desired width after cropping.\n        scale: Scale range of Inception-style area based random resizing.\n        ratio: Aspect ratio range of Inception-style area based random resizing.\n    \"\"\"\n    t = images.shape[1]\n    height = images.shape[2]\n    width = images.shape[3]\n\n    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)\n    i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)\n    i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]\n    j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]\n    h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]\n    w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]\n    out = torch.zeros((3, t, target_height, target_width))\n    for ind in range(t):\n        out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(\n            images[\n                :,\n                ind : ind + 1,\n                i_s[ind] : i_s[ind] + h_s[ind],\n                j_s[ind] : j_s[ind] + w_s[ind],\n            ],\n            size=(target_height, target_width),\n            mode=\"bilinear\",\n            align_corners=False,\n        )\n    return out\n\n\ndef create_random_augment(\n    input_size,\n    auto_augment=None,\n    interpolation=\"bilinear\",\n):\n    \"\"\"\n    Get video randaug transform.\n\n    Args:\n        input_size: The size of the input video in tuple.\n        auto_augment: Parameters for randaug. 
An example:\n            \"rand-m7-n4-mstd0.5-inc1\" (m is the magnitude and n is the number\n            of operations to apply).\n        interpolation: Interpolation method.\n    \"\"\"\n    if isinstance(input_size, tuple):\n        img_size = input_size[-2:]\n    else:\n        img_size = input_size\n\n    if auto_augment:\n        assert isinstance(auto_augment, str)\n        if isinstance(img_size, tuple):\n            img_size_min = min(img_size)\n        else:\n            img_size_min = img_size\n        aa_params = {\"translate_const\": int(img_size_min * 0.45)}\n        if interpolation and interpolation != \"random\":\n            aa_params[\"interpolation\"] = _pil_interp(interpolation)\n        if auto_augment.startswith(\"rand\"):\n            return transforms.Compose(\n                [rand_augment_transform(auto_augment, aa_params)]\n            )\n    raise NotImplementedError\n\n\ndef random_sized_crop_img(\n    im,\n    size,\n    jitter_scale=(0.08, 1.0),\n    jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),\n    max_iter=10,\n):\n    \"\"\"\n    Performs Inception-style cropping (used for training).\n    \"\"\"\n    assert (\n        len(im.shape) == 3\n    ), \"Currently only support image for random_sized_crop\"\n    h, w = im.shape[1:3]\n    i, j, h, w = _get_param_spatial_crop(\n        scale=jitter_scale,\n        ratio=jitter_aspect,\n        height=h,\n        width=w,\n        num_repeat=max_iter,\n        log_scale=False,\n        switch_hw=True,\n    )\n    cropped = im[:, i : i + h, j : j + w]\n    return torch.nn.functional.interpolate(\n        cropped.unsqueeze(0),\n        size=(size, size),\n        mode=\"bilinear\",\n        align_corners=False,\n    ).squeeze(0)\n\n\n# The following code is adapted from the timm library; it will be replaced with a\n# dependency on PyTorchVideo.\n# https://github.com/facebookresearch/pytorchvideo\nclass RandomResizedCropAndInterpolation:\n    \"\"\"Crop the given PIL Image to random 
size and aspect ratio with random interpolation.\n    A crop of random size (default: of 0.08 to 1.0) of the original size and a random\n    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop\n    is finally resized to given size.\n    This is popularly used to train the Inception networks.\n    Args:\n        size: expected output size of each edge\n        scale: range of size of the origin size cropped\n        ratio: range of aspect ratio of the origin aspect ratio cropped\n        interpolation: Default: PIL.Image.BILINEAR\n    \"\"\"\n\n    def __init__(\n        self,\n        size,\n        scale=(0.08, 1.0),\n        ratio=(3.0 / 4.0, 4.0 / 3.0),\n        interpolation=\"bilinear\",\n    ):\n        if isinstance(size, tuple):\n            self.size = size\n        else:\n            self.size = (size, size)\n        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):\n            print(\"range should be of kind (min, max)\")\n\n        if interpolation == \"random\":\n            self.interpolation = _RANDOM_INTERPOLATION\n        else:\n            self.interpolation = _pil_interp(interpolation)\n        self.scale = scale\n        self.ratio = ratio\n\n    @staticmethod\n    def get_params(img, scale, ratio):\n        \"\"\"Get parameters for ``crop`` for a random sized crop.\n        Args:\n            img (PIL Image): Image to be cropped.\n            scale (tuple): range of size of the origin size cropped\n            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped\n        Returns:\n            tuple: params (i, j, h, w) to be passed to ``crop`` for a random\n                sized crop.\n        \"\"\"\n        area = img.size[0] * img.size[1]\n\n        for _ in range(10):\n            target_area = random.uniform(*scale) * area\n            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))\n            aspect_ratio = math.exp(random.uniform(*log_ratio))\n\n            w = 
int(round(math.sqrt(target_area * aspect_ratio)))\n            h = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if w <= img.size[0] and h <= img.size[1]:\n                i = random.randint(0, img.size[1] - h)\n                j = random.randint(0, img.size[0] - w)\n                return i, j, h, w\n\n        # Fallback to central crop\n        in_ratio = img.size[0] / img.size[1]\n        if in_ratio < min(ratio):\n            w = img.size[0]\n            h = int(round(w / min(ratio)))\n        elif in_ratio > max(ratio):\n            h = img.size[1]\n            w = int(round(h * max(ratio)))\n        else:  # whole image\n            w = img.size[0]\n            h = img.size[1]\n        i = (img.size[1] - h) // 2\n        j = (img.size[0] - w) // 2\n        return i, j, h, w\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be cropped and resized.\n        Returns:\n            PIL Image: Randomly cropped and resized image.\n        \"\"\"\n        i, j, h, w = self.get_params(img, self.scale, self.ratio)\n        if isinstance(self.interpolation, (tuple, list)):\n            interpolation = random.choice(self.interpolation)\n        else:\n            interpolation = self.interpolation\n        return F.resized_crop(img, i, j, h, w, self.size, interpolation)\n\n    def __repr__(self):\n        if isinstance(self.interpolation, (tuple, list)):\n            interpolate_str = \" \".join(\n                [_pil_interpolation_to_str[x] for x in self.interpolation]\n            )\n        else:\n            interpolate_str = _pil_interpolation_to_str[self.interpolation]\n        format_string = self.__class__.__name__ + \"(size={0}\".format(self.size)\n        format_string += \", scale={0}\".format(\n            tuple(round(s, 4) for s in self.scale)\n        )\n        format_string += \", ratio={0}\".format(\n            tuple(round(r, 4) for r in self.ratio)\n        )\n        format_string 
+= \", interpolation={0})\".format(interpolate_str)\n        return format_string\n"
  },
  {
    "path": "vision_transformer.py",
    "content": "#!/usr/bin/env python\n\nfrom collections import OrderedDict\nimport numpy as np\nfrom typing import Tuple\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n'''\nQuickGELU and LayerNorm w/ fp16 from official CLIP repo\n(https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py)\n'''\nclass QuickGELU(nn.Module):\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16.\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        ret = super().forward(x.type(torch.float32))\n        return ret.type(orig_type)\n\n\nclass Attention(nn.Module):\n    '''\n    A generalized attention module with more flexibility.\n    '''\n\n    def __init__(\n        self, q_in_dim: int, k_in_dim: int, v_in_dim: int,\n        qk_proj_dim: int, v_proj_dim: int, num_heads: int, out_dim: int,\n        return_all_features: bool = False,\n    ):\n        super().__init__()\n\n        self.q_proj = nn.Linear(q_in_dim, qk_proj_dim)\n        self.k_proj = nn.Linear(k_in_dim, qk_proj_dim)\n        self.v_proj = nn.Linear(v_in_dim, v_proj_dim)\n        self.out_proj = nn.Linear(v_proj_dim, out_dim)\n\n        self.num_heads = num_heads\n        self.return_all_features = return_all_features\n        assert qk_proj_dim % num_heads == 0 and v_proj_dim % num_heads == 0\n\n        self._initialize_weights()\n\n\n    def _initialize_weights(self):\n        for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):\n            nn.init.xavier_uniform_(m.weight)\n            nn.init.constant_(m.bias, 0.)\n\n\n    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):\n        assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3\n        N = q.size(0); assert k.size(0) == N and v.size(0) == N\n        Lq, Lkv = q.size(1), k.size(1); assert v.size(1) == Lkv\n\n        q, k, v 
= self.q_proj(q), self.k_proj(k), self.v_proj(v)\n        \n        H = self.num_heads\n        Cqk, Cv = q.size(-1) // H, v.size(-1) // H\n\n        q = q.view(N, Lq, H, Cqk)\n        k = k.view(N, Lkv, H, Cqk)\n        v = v.view(N, Lkv, H, Cv)\n\n        aff = torch.einsum('nqhc,nkhc->nqkh', q / (Cqk ** 0.5), k)\n        aff = aff.softmax(dim=-2)\n        mix = torch.einsum('nqlh,nlhc->nqhc', aff, v)\n\n        out = self.out_proj(mix.flatten(-2))\n\n        if self.return_all_features:\n            return dict(q=q, k=k, v=v, aff=aff, out=out)\n        else:\n            return out\n\n\nclass PatchEmbed2D(nn.Module):\n\n    def __init__(\n        self,\n        patch_size: Tuple[int, int] = (16, 16),\n        in_channels: int = 3,\n        embed_dim: int = 768,\n    ):\n        super().__init__()\n\n        self.patch_size = patch_size\n        self.in_channels = in_channels\n\n        self.proj = nn.Linear(np.prod(patch_size) * in_channels, embed_dim)\n\n\n    def _initialize_weights(self, x):\n        nn.init.kaiming_normal_(self.proj.weight, 0.)\n        nn.init.constant_(self.proj.bias, 0.)\n\n\n    def forward(self, x: torch.Tensor):\n        B, C, H, W = x.size()\n        pH, pW = self.patch_size\n\n        assert C == self.in_channels and H % pH == 0 and W % pW == 0\n\n        x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2)\n        x = self.proj(x)\n        \n        return x\n\nclass TransformerEncoderLayer(nn.Module):\n\n    def __init__(\n        self,\n        in_feature_dim: int = 768,\n        qkv_dim: int = 768,\n        num_heads: int = 12,\n        mlp_factor: float = 4.0,\n        mlp_dropout: float = 0.0,\n        act: nn.Module = QuickGELU,\n        return_all_features: bool = False,\n    ):\n        super().__init__()\n\n        self.return_all_features = return_all_features\n\n        self.attn = Attention(\n            q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,\n 
           qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,\n            return_all_features=return_all_features,\n        )\n\n        mlp_dim = round(mlp_factor * in_feature_dim)\n        self.mlp = nn.Sequential(OrderedDict([\n            ('fc1', nn.Linear(in_feature_dim, mlp_dim)),\n            ('act', act()),\n            ('dropout', nn.Dropout(mlp_dropout)),\n            ('fc2', nn.Linear(mlp_dim, in_feature_dim)),\n        ]))\n\n        self.norm1 = LayerNorm(in_feature_dim)\n        self.norm2 = LayerNorm(in_feature_dim)\n\n        self._initialize_weights()\n\n\n    def _initialize_weights(self):\n        for m in (self.mlp[0], self.mlp[-1]):\n            nn.init.xavier_uniform_(m.weight)\n            nn.init.normal_(m.bias, std=1e-6)\n\n\n    def forward(self, x: torch.Tensor):\n        if self.return_all_features:\n            ret_dict = {}\n            \n            x_norm = self.norm1(x)\n            attn_out = self.attn(x_norm, x_norm, x_norm)\n            ret_dict['q'] = attn_out['q']\n            ret_dict['k'] = attn_out['k']\n            ret_dict['v'] = attn_out['v']\n            ret_dict['attn_out'] = attn_out['out']\n            x = x + attn_out['out']\n\n            x = x + self.mlp(self.norm2(x))\n            ret_dict['out'] = x\n\n            return ret_dict\n        \n        else:\n            x_norm = self.norm1(x)\n            x = x + self.attn(x_norm, x_norm, x_norm)\n            x = x + self.mlp(self.norm2(x))\n\n            return x\n\n\nclass TransformerDecoderLayer(nn.Module):\n\n    def __init__(\n        self,\n        in_feature_dim: int = 768,\n        qkv_dim: int = 768,\n        num_heads: int = 12,\n        mlp_factor: float = 4.0,\n        mlp_dropout: float = 0.0,\n        act: nn.Module = QuickGELU,\n    ):\n        super().__init__()\n\n        self.attn = Attention(\n            q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,\n            qk_proj_dim=qkv_dim, 
v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,\n        )\n\n        mlp_dim = round(mlp_factor * in_feature_dim)\n        self.mlp = nn.Sequential(OrderedDict([\n            ('fc1', nn.Linear(in_feature_dim, mlp_dim)),\n            ('act', act()),\n            ('dropout', nn.Dropout(mlp_dropout)),\n            ('fc2', nn.Linear(mlp_dim, in_feature_dim)),\n        ]))\n\n        self.norm1 = LayerNorm(in_feature_dim)\n        self.norm2 = LayerNorm(in_feature_dim)\n        self.norm3 = LayerNorm(in_feature_dim)\n\n        self._initialize_weights()\n\n\n    def _initialize_weights(self):\n        for m in (self.mlp[0], self.mlp[-1]):\n            nn.init.xavier_uniform_(m.weight)\n            nn.init.normal_(m.bias, std=1e-6)\n\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor):\n        y_norm = self.norm3(y)\n        x = x + self.attn(self.norm1(x), y_norm, y_norm)\n        x = x + self.mlp(self.norm2(x))\n\n        return x\n\n\nclass VisionTransformer2D(nn.Module):\n\n    def __init__(\n        self,\n        feature_dim: int = 768,\n        input_size: Tuple[int, int] = (224, 224),\n        patch_size: Tuple[int, int] = (16, 16),\n        num_heads: int = 12,\n        num_layers: int = 12,\n        mlp_factor: float = 4.0,\n        act: nn.Module = QuickGELU,\n        return_all_features: bool = False,\n        ln_pre: bool = False,\n    ):\n        super().__init__()\n\n        self.return_all_features = return_all_features\n        \n        self.patch_embed = PatchEmbed2D(patch_size=patch_size, embed_dim=feature_dim)\n        self.num_patches = np.prod([x // y for x, y in zip(input_size, patch_size)]) + 1\n\n        self.cls_token = nn.Parameter(torch.zeros([feature_dim]))\n        self.pos_embed = nn.Parameter(torch.zeros([self.num_patches, feature_dim]))\n\n        self.blocks = nn.ModuleList([\n            TransformerEncoderLayer(\n                in_feature_dim=feature_dim, qkv_dim=feature_dim, num_heads=num_heads, 
mlp_factor=mlp_factor, act=act,\n                return_all_features=return_all_features,\n            ) for _ in range(num_layers)\n        ])\n\n        if ln_pre:\n            self.ln_pre = LayerNorm(feature_dim)\n        else:\n            self.ln_pre = nn.Identity()\n\n        self._initialize_weights()\n\n\n    def _initialize_weights(self):\n        nn.init.normal_(self.cls_token, std=0.02)\n        nn.init.normal_(self.pos_embed, std=0.02)\n\n    def forward(self, x: torch.Tensor):\n        dtype = self.patch_embed.proj.weight.dtype\n        x = x.to(dtype)\n\n        x = self.patch_embed(x)\n        x = torch.cat([self.cls_token.view(1, 1, -1).repeat(x.size(0), 1, 1), x], dim=1)\n        x = x + self.pos_embed\n\n        x = self.ln_pre(x)\n\n        if self.return_all_features:\n            all_features = []\n            for blk in self.blocks:\n                x = blk(x)\n                all_features.append(x)\n                x = x['out']\n            return all_features\n        \n        else:\n            for blk in self.blocks:\n                x = blk(x)\n            return x\n\n\ndef model_to_fp16(model: VisionTransformer2D):\n    def _module_to_fp16(m: nn.Module):\n        if isinstance(m, (nn.Linear,)):\n            m.half()\n    model.apply(_module_to_fp16)\n\n    model.pos_embed.data = model.pos_embed.data.half()\n    model.cls_token.data = model.cls_token.data.half()\n\n\nvit_presets = {\n    'ViT-B/16-lnpre': dict(\n        feature_dim=768,\n        input_size=(224, 224),\n        patch_size=(16, 16),\n        num_heads=12,\n        num_layers=12,\n        mlp_factor=4.0,\n        ln_pre=True,\n    ),\n    'ViT-L/14-lnpre': dict(\n        feature_dim=1024,\n        input_size=(224, 224),\n        patch_size=(14, 14),\n        num_heads=16,\n        num_layers=24,\n        mlp_factor=4.0,\n        ln_pre=True,\n    ),\n}"
  },
  {
    "path": "weight_loaders.py",
    "content": "#!/usr/bin/env python\n\nimport os, sys\nfrom typing import Dict\n\nimport torch\n\n__all__ = ['weight_loader_fn_dict']\n\ndef load_weights_clip(load_path: str) -> Dict[str, torch.Tensor]:\n    clip_model = torch.jit.load(load_path, map_location='cpu')\n    clip_model = clip_model.visual\n    src_state_dict = clip_model.state_dict()\n    src_state_dict = dict((k, v.float()) for k, v in src_state_dict.items())\n\n    dst_state_dict = {}\n    \n    dst_state_dict['cls_token'] = src_state_dict['class_embedding']\n    dst_state_dict['pos_embed'] = src_state_dict['positional_embedding']\n    dst_state_dict['patch_embed.proj.weight'] = src_state_dict['conv1.weight'].flatten(1)\n    dst_state_dict['patch_embed.proj.bias'] = torch.zeros([src_state_dict['conv1.weight'].size(0)])\n    \n    dst_state_dict['ln_pre.weight'] = src_state_dict['ln_pre.weight']\n    dst_state_dict['ln_pre.bias'] = src_state_dict['ln_pre.bias']\n\n    block_idx = 0\n    while True:\n        src_prefix = 'transformer.resblocks.%d.' % block_idx\n        dst_prefix = 'blocks.%d.' 
% block_idx\n\n        src_block_state_dict = dict((k[len(src_prefix):], v) for k, v in src_state_dict.items() if k.startswith(src_prefix))\n        if len(src_block_state_dict) == 0:\n            break\n\n        dst_block_state_dict = {}\n        feat_dim = src_block_state_dict['ln_1.weight'].size(0)\n\n        for i, dst_name in enumerate(('q', 'k', 'v')):\n            dst_block_state_dict['attn.%s_proj.weight' % dst_name] = src_block_state_dict['attn.in_proj_weight'][feat_dim * i: feat_dim * (i + 1)]\n            dst_block_state_dict['attn.%s_proj.bias' % dst_name] = src_block_state_dict['attn.in_proj_bias'][feat_dim * i: feat_dim * (i + 1)]\n        \n        dst_block_state_dict['attn.out_proj.weight'] = src_block_state_dict['attn.out_proj.weight']\n        dst_block_state_dict['attn.out_proj.bias'] = src_block_state_dict['attn.out_proj.bias']\n\n        dst_block_state_dict['mlp.fc1.weight'] = src_block_state_dict['mlp.c_fc.weight']\n        dst_block_state_dict['mlp.fc1.bias'] = src_block_state_dict['mlp.c_fc.bias']\n        dst_block_state_dict['mlp.fc2.weight'] = src_block_state_dict['mlp.c_proj.weight']\n        dst_block_state_dict['mlp.fc2.bias'] = src_block_state_dict['mlp.c_proj.bias']\n\n        dst_block_state_dict['norm1.weight'] = src_block_state_dict['ln_1.weight']\n        dst_block_state_dict['norm1.bias'] = src_block_state_dict['ln_1.bias']\n        dst_block_state_dict['norm2.weight'] = src_block_state_dict['ln_2.weight']\n        dst_block_state_dict['norm2.bias'] = src_block_state_dict['ln_2.bias']\n\n        dst_state_dict.update(dict((dst_prefix + k, v) for k, v in dst_block_state_dict.items()))\n        block_idx += 1\n\n    return dst_state_dict\n\n\nweight_loader_fn_dict = {\n    'clip': load_weights_clip,\n}\n"
  }
]