Repository: OpenGVLab/efficient-video-recognition
Branch: master
Commit: b282e8b49e4e
Files: 26
Total size: 120.7 KB

Directory structure:
gitextract_7dyo2vh1/
├── .gitignore
├── README.md
├── checkpoint.py
├── data/
│   └── k400_class_mappings.json
├── main.py
├── model.py
├── scripts/
│   ├── eval_k400_vitb16_16f_dec4x768.sh
│   ├── eval_k400_vitb16_32f_dec4x768.sh
│   ├── eval_k400_vitb16_8f_dec4x768.sh
│   ├── eval_k400_vitl14_16f_dec4x1024.sh
│   ├── eval_k400_vitl14_32f_dec4x1024.sh
│   ├── eval_k400_vitl14_8f_dec4x1024.sh
│   ├── train_k400_vitb16_16f_dec4x768.sh
│   ├── train_k400_vitb16_32f_dec4x768.sh
│   ├── train_k400_vitb16_8f_dec4x768.sh
│   ├── train_k400_vitl14_16f_dec4x1024.sh
│   ├── train_k400_vitl14_32f_dec4x1024.sh
│   └── train_k400_vitl14_8f_dec4x1024.sh
├── video_dataset/
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── rand_augment.py
│   ├── random_erasing.py
│   └── transform.py
├── vision_transformer.py
└── weight_loaders.py


================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
*.py[cod]
*$py.class
runs/
video_dataset/io_internal.py


================================================
FILE: README.md
================================================
# Frozen CLIP models are Efficient Video Learners

This is the official implementation of the paper [Frozen CLIP models are Efficient Video Learners](https://arxiv.org/abs/2208.03550).

```
@article{lin2022frozen,
  title={Frozen CLIP Models are Efficient Video Learners},
  author={Lin, Ziyi and Geng, Shijie and Zhang, Renrui and Gao, Peng and de Melo, Gerard and Wang, Xiaogang and Dai, Jifeng and Qiao, Yu and Li, Hongsheng},
  journal={arXiv preprint arXiv:2208.03550},
  year={2022}
}
```

## Introduction

The overall architecture of the EVL framework includes a trainable Transformer decoder, trainable local temporal modules, and a pretrained, fixed image backbone (e.g., CLIP). Using a fixed backbone significantly reduces training time: we trained a ViT-B/16 model with 8 frames for 50 epochs in 60 GPU-hours (NVIDIA V100).

Despite the small training compute and memory footprint, EVL models achieve high performance on Kinetics-400. A comparison with state-of-the-art methods is shown in the following figure:

[Figure: comparison with state-of-the-art methods]

## Installation

We tested the released code with the following conda environment:

```
conda create -n pt1.9.0cu11.1_official -c pytorch -c conda-forge pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0 cudatoolkit torchvision av
```

## Data Preparation

We expect the `--train_list_path` and `--val_list_path` command-line arguments to point to data list files in the following format:

```
...
```

where `` points to a video file and `` is an integer between `0` and `num_classes - 1`. `--num_classes` should also be specified on the command line. Additionally, `` may be a relative path when `--data_root` is specified, in which case it is resolved relative to the directory passed as `--data_root`.

The class mappings used by the released weights are provided in [Kinetics-400 class mappings](data/k400_class_mappings.json).

## Backbone Preparation

CLIP weights need to be downloaded from the [CLIP official repo](https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/clip.py#L30) and passed to the `--backbone_path` command-line argument.

## Script Usage

Training and evaluation scripts are provided in the `scripts` folder. The scripts should be ready to run once the environment is set up and `--backbone_path`, `--train_list_path` and `--val_list_path` (plus `--resume_path` in the evaluation scripts) are replaced with your own paths. For other command-line arguments, see the help message (`python main.py --help`).

## Kinetics-400 Main Results

This is a re-implementation for open-source use. We are still re-running some models; their scripts, weights and logs will be released later.
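
The `#frames x stride` column in the results table corresponds to the `--num_frames` and `--sampling_rate` arguments in the scripts (for example, the 16-frame ViT-B/16 script uses `--num_frames 16 --sampling_rate 16`). As a minimal sketch, assuming uniform strided sampling from a hypothetical start index (`clip_frame_indices` and `temporal_span` are illustrative helpers, not functions from this repository, whose actual sampler lives in `video_dataset/` and is not shown here), the frames covered by one clip can be computed like this:

```python
from typing import List


def clip_frame_indices(start: int, num_frames: int, stride: int) -> List[int]:
    """Return the indices of `num_frames` frames taken every `stride` frames.

    Hypothetical helper for illustration only: the repository's own sampler
    may choose `start` differently and handle videos shorter than the clip
    in its own way.
    """
    return [start + i * stride for i in range(num_frames)]


def temporal_span(num_frames: int, stride: int) -> int:
    """Number of source frames covered by one clip, first to last inclusive."""
    return (num_frames - 1) * stride + 1


# The three "#frames x stride" settings that appear in the results table.
for num_frames, stride in [(8, 16), (16, 16), (32, 8)]:
    indices = clip_frame_indices(0, num_frames, stride)
    print(f"{num_frames} x {stride}: first indices {indices[:3]}, "
          f"span {temporal_span(num_frames, stride)} frames")
```

Under this assumption, `16 x 16` and `32 x 8` clips cover similar windows (241 and 249 frames), with `32 x 8` sampling twice as densely, while `8 x 16` covers roughly half as much (113 frames).
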
In the following table we report the re-run accuracy, which may differ slightly from the original paper (typically by +/-0.1%):

| Backbone | Decoder Layers | #frames x stride | top-1 (%) | top-5 (%) | Script | Model | Log |
| - | - | - | - | - | - | - | - |
| ViT-B/16 | 4 | 8 x 16 | 82.8 | 95.8 | [script](scripts/train_k400_vitb16_8f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1DoGjvDdkJoSa9i-wq1lh6QoEZIa4xTB3/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1-9vgsXMpnWBI9MxQV7SSQhkPfLomoYY3/view?usp=sharing) |
| ViT-B/16 | 4 | 16 x 16 | 83.7 | 96.2 | [script](scripts/train_k400_vitb16_16f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1dax4qUIOEI_QzYXv31J-87cDkonQetVQ/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1l2ivY28jUpwSmafQZvwtUo7tvm42i0PL/view?usp=sharing) |
| ViT-B/16 | 4 | 32 x 8 | 84.3 | 96.6 | [script](scripts/train_k400_vitb16_32f_dec4x768.sh) | [google drive](https://drive.google.com/file/d/1fzFM5pD39Kfp8xRAJuWaXR9RALLmnoeU/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1X1ZOdSCxXVeMpNhr_bviNKlRfJa5SMD7/view?usp=sharing) |
| ViT-L/14 | 4 | 8 x 16 | 86.3 | 97.2 | [script](scripts/train_k400_vitl14_8f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1AkdF4CkOVW2uiycCVqCxS397oYxNISAI/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1OJFBmaE_tAwTzG-4i0CLQmhwGnN0psx1/view?usp=sharing) |
| ViT-L/14 | 4 | 16 x 16 | 86.9 | 97.4 | [script](scripts/train_k400_vitl14_16f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1CTV9geLD3HLWzByAQUOf_m0F_g2lE3rg/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1a2iC4tQvjWFMI3UrEv2chuHwVrF6p9YF/view?usp=sharing) |
| ViT-L/14 | 4 | 32 x 8 | 87.7 | 97.6 | [script](scripts/train_k400_vitl14_32f_dec4x1024.sh) | [google drive](https://drive.google.com/file/d/1zNFNCKwP5owakELlnTCD20cpVQBqgJrB/view?usp=sharing) | [google drive](https://drive.google.com/file/d/1dK7qoz3McYrmfS09FfreXC-LjUM7l0u4/view?usp=sharing) |
| ViT-L/14 (336px) | 4 | 32 x 8 | 87.7 | 97.8 | | | |

## Data Loading Speed

As the training process is fast, video frames are consumed at a very high rate. For easier installation, the current version uses PyTorch's built-in data loaders. They are not very efficient and can become a bottleneck when using ViT-B backbones, although model accuracy should not be affected. We provide a `--dummy_dataset` option that bypasses actual video decoding, for measuring training speed. Our internal data loader is implemented in pure C++ and does not significantly bottleneck training on a machine with 2x Xeon Gold 6148 CPUs and 4x V100 GPUs.

## Acknowledgements

The data loader code is modified from [PySlowFast](https://github.com/facebookresearch/SlowFast). Thanks for their awesome work!


================================================
FILE: checkpoint.py
================================================
#!/usr/bin/env python

import argparse
import os

import torch
import torch.distributed as dist


def setup_arg_parser(parser: argparse.ArgumentParser):
    parser.add_argument('--checkpoint_dir', type=str,
                        help='checkpoint output path')
    parser.add_argument('--auto_resume', action='store_true',
                        help='auto resume from the last checkpoint from checkpoint_dir')
    parser.add_argument('--resume_path', type=str,
                        help='resume from manually specified checkpoint file, overriding auto_resume')
    parser.add_argument('--pretrain', type=str,
                        help='path to pretrained weights. Will NOT override auto_resume or resume_path, '
                             'load optimizer state, or enforce strict matching of checkpoint and model weights.')


def _find_autoresume_path(args: argparse.Namespace):
    print('Trying to auto resume from path:', args.checkpoint_dir)

    # Guard against checkpoint_dir being unset: os.path.isdir(None) raises TypeError.
    if args.checkpoint_dir is not None and os.path.isdir(args.checkpoint_dir):
        checkpoint_files = [x for x in os.listdir(args.checkpoint_dir)
                            if x.startswith('checkpoint-') and x.endswith('.pth')]
        checkpoint_iters = []
        for x in checkpoint_files:
            try:
                x = x[len('checkpoint-'): -len('.pth')]
                x = int(x)
            except ValueError:
                continue
            checkpoint_iters.append(x)
    else:
        checkpoint_iters = []

    if len(checkpoint_iters) == 0:
        print('Did not find a valid checkpoint file.')
    else:
        checkpoint_iters.sort()
        args.resume_path = os.path.join(args.checkpoint_dir, 'checkpoint-%d.pth' % checkpoint_iters[-1])
        print(f'Found {len(checkpoint_iters)} checkpoint file(s).')


def resume_from_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_sched: torch.optim.lr_scheduler._LRScheduler,
    loss_scaler: torch.cuda.amp.grad_scaler.GradScaler,
    args: argparse.Namespace,
) -> int:
    if args.pretrain is not None:
        print(f'Loading pretrain model: {args.pretrain}')
        ckpt = torch.load(args.pretrain, map_location='cpu')
        print(model.load_state_dict(ckpt['model'], strict=False))

    # Returns resume_step on a successful resume, or 0 otherwise.
    if args.auto_resume and args.resume_path is None:
        _find_autoresume_path(args)

    if args.resume_path is None:
        print('Not resuming from a checkpoint.')
        return 0
    else:
        print(f'Resuming from checkpoint file {args.resume_path}')
        ckpt = torch.load(args.resume_path, map_location='cpu')
        model.load_state_dict(ckpt['model'], strict=True)
        if 'optimizer' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
            lr_sched.load_state_dict(ckpt['lr_sched'])
            loss_scaler.load_state_dict(ckpt['loss_scaler'])
            return ckpt['next_step']
        else:
            print('Optimizer state is NOT found in checkpoint.')
            return 0


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_sched: torch.optim.lr_scheduler._LRScheduler,
    loss_scaler: torch.cuda.amp.grad_scaler.GradScaler,
    next_step: int,
    args: argparse.Namespace,
):
    if args.checkpoint_dir is None:
        return

    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    to_save = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_sched': lr_sched.state_dict(),
        'loss_scaler': loss_scaler.state_dict(),
        'next_step': next_step,
    }

    torch.save(to_save, os.path.join(args.checkpoint_dir, f'checkpoint-{next_step}.pth'))


================================================
FILE: data/k400_class_mappings.json
================================================
[
  "abseiling", "air drumming", "answering questions", "applauding", "applying cream", "archery", "arm wrestling", "arranging flowers", "assembling computer", "auctioning", "baby waking up", "baking cookies", "balloon blowing", "bandaging", "barbequing", "bartending", "beatboxing", "bee keeping", "belly dancing", "bench pressing", "bending back", "bending metal", "biking through snow", "blasting sand", "blowing glass", "blowing leaves", "blowing nose", "blowing out candles", "bobsledding", "bookbinding", "bouncing on trampoline", "bowling", "braiding hair", "breading or breadcrumbing", "breakdancing", "brush painting", "brushing hair", "brushing teeth", "building cabinet", "building shed", "bungee jumping", "busking", "canoeing or kayaking", "capoeira", "carrying baby", "cartwheeling", "carving pumpkin", "catching fish", "catching or throwing baseball", "catching or throwing frisbee", "catching or throwing softball", "celebrating", "changing oil", "changing wheel", "checking tires", "cheerleading", "chopping wood", "clapping", "clay pottery making", "clean and jerk", "cleaning floor", "cleaning gutters", "cleaning pool", "cleaning shoes", "cleaning toilet", "cleaning windows", "climbing a rope", "climbing ladder", "climbing tree", "contact juggling", "cooking chicken", "cooking egg", "cooking on campfire", "cooking sausages", "counting money", "country line dancing", "cracking neck", "crawling baby", "crossing river", "crying", "curling hair", "cutting nails", "cutting pineapple", "cutting watermelon", "dancing ballet", "dancing charleston", "dancing gangnam style", "dancing macarena", "deadlifting", "decorating the christmas tree", "digging", "dining", "disc golfing", "diving cliff", "dodgeball", "doing aerobics", "doing laundry", "doing nails", "drawing", "dribbling basketball", "drinking", "drinking beer", "drinking shots", "driving car", "driving tractor", "drop kicking", "drumming fingers", "dunking basketball", "dying hair", "eating burger", "eating cake", "eating carrots", "eating chips", "eating doughnuts", "eating hotdog", "eating ice cream", "eating spaghetti", "eating watermelon", "egg hunting", "exercising arm", "exercising with an exercise ball", "extinguishing fire", "faceplanting", "feeding birds", "feeding fish", "feeding goats", "filling eyebrows", "finger snapping", "fixing hair", "flipping pancake", "flying kite", "folding clothes", "folding napkins", "folding paper", "front raises", "frying vegetables", "garbage collecting", "gargling", "getting a haircut", "getting a tattoo", "giving or receiving award", "golf chipping", "golf driving", "golf putting", "grinding meat", "grooming dog", "grooming horse",
  "gymnastics tumbling", "hammer throw", "headbanging", "headbutting", "high jump", "high kick", "hitting baseball", "hockey stop", "holding snake", "hopscotch", "hoverboarding", "hugging", "hula hooping", "hurdling", "hurling (sport)", "ice climbing", "ice fishing", "ice skating", "ironing", "javelin throw", "jetskiing", "jogging", "juggling balls", "juggling fire", "juggling soccer ball", "jumping into pool", "jumpstyle dancing", "kicking field goal", "kicking soccer ball", "kissing", "kitesurfing", "knitting", "krumping", "laughing", "laying bricks", "long jump", "lunge", "making a cake", "making a sandwich", "making bed", "making jewelry", "making pizza", "making snowman", "making sushi", "making tea", "marching", "massaging back", "massaging feet", "massaging legs", "massaging person's head", "milking cow", "mopping floor", "motorcycling", "moving furniture", "mowing lawn", "news anchoring", "opening bottle", "opening present", "paragliding", "parasailing", "parkour", "passing American football (in game)", "passing American football (not in game)", "peeling apples", "peeling potatoes", "petting animal (not cat)", "petting cat", "picking fruit", "planting trees", "plastering", "playing accordion", "playing badminton", "playing bagpipes", "playing basketball", "playing bass guitar", "playing cards", "playing cello", "playing chess", "playing clarinet", "playing controller", "playing cricket", "playing cymbals", "playing didgeridoo", "playing drums", "playing flute", "playing guitar", "playing harmonica", "playing harp", "playing ice hockey", "playing keyboard", "playing kickball", "playing monopoly", "playing organ", "playing paintball", "playing piano", "playing poker", "playing recorder", "playing saxophone", "playing squash or racquetball", "playing tennis", "playing trombone", "playing trumpet", "playing ukulele", "playing violin", "playing volleyball", "playing xylophone", "pole vault", "presenting weather forecast", "pull ups", "pumping fist", "pumping gas",
  "punching bag", "punching person (boxing)", "push up", "pushing car", "pushing cart", "pushing wheelchair", "reading book", "reading newspaper", "recording music", "riding a bike", "riding camel", "riding elephant", "riding mechanical bull", "riding mountain bike", "riding mule", "riding or walking with horse", "riding scooter", "riding unicycle", "ripping paper", "robot dancing", "rock climbing", "rock scissors paper", "roller skating", "running on treadmill", "sailing", "salsa dancing", "sanding floor", "scrambling eggs", "scuba diving", "setting table", "shaking hands", "shaking head", "sharpening knives", "sharpening pencil", "shaving head", "shaving legs", "shearing sheep", "shining shoes", "shooting basketball", "shooting goal (soccer)", "shot put", "shoveling snow", "shredding paper", "shuffling cards", "side kick", "sign language interpreting", "singing", "situp", "skateboarding", "ski jumping", "skiing (not slalom or crosscountry)", "skiing crosscountry", "skiing slalom", "skipping rope", "skydiving", "slacklining", "slapping", "sled dog racing", "smoking", "smoking hookah", "snatch weight lifting", "sneezing", "sniffing", "snorkeling", "snowboarding", "snowkiting", "snowmobiling", "somersaulting", "spinning poi", "spray painting", "spraying", "springboard diving", "squat", "sticking tongue out", "stomping grapes", "stretching arm", "stretching leg", "strumming guitar", "surfing crowd", "surfing water", "sweeping floor", "swimming backstroke", "swimming breast stroke", "swimming butterfly stroke", "swing dancing", "swinging legs", "swinging on something", "sword fighting", "tai chi", "taking a shower", "tango dancing", "tap dancing", "tapping guitar", "tapping pen", "tasting beer", "tasting food", "testifying", "texting", "throwing axe", "throwing ball", "throwing discus", "tickling", "tobogganing", "tossing coin", "tossing salad", "training dog", "trapezing", "trimming or shaving beard", "trimming trees", "triple jump", "tying bow tie", "tying knot (not on a tie)", "tying tie", "unboxing", "unloading truck", "using computer", "using remote controller (not gaming)", "using segway", "vault", "waiting in line", "walking the dog", "washing dishes", "washing feet", "washing hair", "washing hands", "water skiing", "water sliding", "watering plants", "waxing back", "waxing chest", "waxing eyebrows", "waxing legs", "weaving basket", "welding", "whistling", "windsurfing", "wrapping present", "wrestling", "writing", "yawning", "yoga", "zumba"
]


================================================
FILE: main.py
================================================
#!/usr/bin/env python

import argparse
from datetime import datetime
import builtins

import torch
import torch.distributed as dist

import video_dataset
import checkpoint
from model import EVLTransformer
from video_dataset import dataloader
from weight_loaders import weight_loader_fn_dict
from vision_transformer import vit_presets


def setup_print(is_master: bool):
    """
    This function disables printing when not in master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            now = datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def main():
    parser = argparse.ArgumentParser()

    video_dataset.setup_arg_parser(parser)
    checkpoint.setup_arg_parser(parser)

    parser.add_argument('--num_steps', type=int,
                        help='number of training steps')
    parser.add_argument('--eval_only', action='store_true',
                        help='run evaluation only')
    parser.add_argument('--save_freq', type=int, default=5000,
                        help='save a checkpoint every N steps')
    parser.add_argument('--eval_freq', type=int, default=5000,
                        help='evaluate every N steps')
    parser.add_argument('--print_freq', type=int, default=10,
                        help='print log message every N steps')

    parser.add_argument('--backbone', type=str, choices=vit_presets.keys(), default='ViT-B/16-lnpre',
                        help='the backbone variant used to generate image feature maps')
    parser.add_argument('--backbone_path', type=str,
                        help='path to pretrained backbone weights')
    parser.add_argument('--backbone_type', type=str, default='clip', choices=weight_loader_fn_dict.keys(),
                        help='type of backbone weights (used to determine how to convert state_dict '
                             'from different pretraining codebases)')
    parser.add_argument('--finetune_backbone', action='store_true',
                        help='finetune backbone weights')
    parser.add_argument('--decoder_num_layers', type=int, default=4,
                        help='number of decoder layers')
    parser.add_argument('--decoder_qkv_dim', type=int, default=768,
                        help='q (k, v) projection output dimensions in decoder attention layers')
    parser.add_argument('--decoder_num_heads', type=int, default=12,
                        help='number of heads in decoder attention layers')
    parser.add_argument('--decoder_mlp_factor', type=float, default=4.0,
                        help='expansion factor of feature dimension in the middle of decoder MLPs')
    parser.add_argument('--num_classes', type=int, default=400,
                        help='number of classes')
    parser.add_argument('--cls_dropout', type=float, default=0.5,
                        help='dropout rate applied before the final classification linear projection')
    parser.add_argument('--decoder_mlp_dropout', type=float, default=0.5,
                        help='dropout rate applied in MLP layers in the decoder')
    parser.add_argument('--no_temporal_conv', action='store_false', dest='temporal_conv',
                        help='disable temporal convolution on frame features')
    parser.add_argument('--no_temporal_pos_embed', action='store_false', dest='temporal_pos_embed',
                        help='disable temporal position embeddings added to frame features')
    parser.add_argument('--no_temporal_cross_attention', action='store_false', dest='temporal_cross_attention',
                        help='disable temporal cross attention on frame query and key features')
    parser.set_defaults(temporal_conv=True, temporal_pos_embed=True, temporal_cross_attention=True)

    parser.add_argument('--lr', type=float, default=4e-4,
                        help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='optimizer weight decay')
    parser.add_argument('--disable_fp16', action='store_false', dest='fp16',
                        help='disable fp16 during training or inference')
    parser.set_defaults(fp16=True)

    parser.add_argument('--batch_split', type=int, default=1,
                        help='optionally split the batch into smaller shards and forward/backward one shard '
                             'at a time to avoid out-of-memory error.')

    args = parser.parse_args()

    dist.init_process_group('nccl')
    setup_print(dist.get_rank() == 0)
    cuda_device_id = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(cuda_device_id)

    model = EVLTransformer(
        backbone_name=args.backbone,
        backbone_type=args.backbone_type,
        backbone_path=args.backbone_path,
        backbone_mode='finetune' if args.finetune_backbone else ('freeze_fp16' if args.fp16 else 'freeze_fp32'),
        decoder_num_layers=args.decoder_num_layers,
        decoder_qkv_dim=args.decoder_qkv_dim,
        decoder_num_heads=args.decoder_num_heads,
        decoder_mlp_factor=args.decoder_mlp_factor,
        num_classes=args.num_classes,
        enable_temporal_conv=args.temporal_conv,
        enable_temporal_pos_embed=args.temporal_pos_embed,
        enable_temporal_cross_attention=args.temporal_cross_attention,
        cls_dropout=args.cls_dropout,
        decoder_mlp_dropout=args.decoder_mlp_dropout,
        num_frames=args.num_frames,
    )
    print(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[cuda_device_id], output_device=cuda_device_id,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_steps)
    loss_scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled=args.fp16)
    criterion = torch.nn.CrossEntropyLoss()

    resume_step = checkpoint.resume_from_checkpoint(model, optimizer, lr_sched, loss_scaler, args)

    val_loader = video_dataset.create_val_loader(args)
    if args.eval_only:
        print('Running in eval_only mode.')
        model.eval()
        evaluate(model, val_loader)
        return
    else:
        assert args.train_list_path is not None, 'Train list path must be specified if not in eval_only mode.'
        train_loader = video_dataset.create_train_loader(args, resume_step=resume_step)

    assert len(train_loader) == args.num_steps - resume_step
    batch_st, train_st = datetime.now(), datetime.now()
    for i, (data, labels) in enumerate(train_loader, resume_step):
        data, labels = data.cuda(), labels.cuda()
        data_ed = datetime.now()

        optimizer.zero_grad()

        assert data.size(0) % args.batch_split == 0
        split_size = data.size(0) // args.batch_split
        hit1, hit5, loss_value = 0, 0, 0
        for j in range(args.batch_split):
            data_slice = data[split_size * j: split_size * (j + 1)]
            labels_slice = labels[split_size * j: split_size * (j + 1)]

            with torch.cuda.amp.autocast(args.fp16):
                logits = model(data_slice)
                loss = criterion(logits, labels_slice)

            if labels.dtype == torch.long:  # no mixup, can calculate accuracy
                hit1 += (logits.topk(1, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
                hit5 += (logits.topk(5, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
            loss_value += loss.item() / args.batch_split

            # Gradients accumulate over the shards; scaling by 1 / batch_split keeps the full-batch mean.
            loss_scaler.scale(loss / args.batch_split).backward()

        loss_scaler.step(optimizer)
        loss_scaler.update()
        lr_sched.step()

        batch_ed = datetime.now()

        if i % args.print_freq == 0:
            sync_tensor = torch.Tensor([loss_value, hit1 / data.size(0), hit5 / data.size(0)]).cuda()
            dist.all_reduce(sync_tensor)
            sync_tensor = sync_tensor.cpu() / dist.get_world_size()
            loss_value, acc1, acc5 = sync_tensor.tolist()

            print(
                f'batch_time: {(batch_ed - batch_st).total_seconds():.3f}  '
                f'data_time: {(data_ed - batch_st).total_seconds():.3f}  '
                f'ETA: {(batch_ed - train_st) / (i - resume_step + 1) * (args.num_steps - i - 1)}  |  '
                f'lr: {optimizer.param_groups[0]["lr"]:.6f}  '
                f'loss: {loss_value:.6f}' + (
                    f'  acc1: {acc1 * 100:.2f}%  acc5: {acc5 * 100:.2f}%' if labels.dtype == torch.long else ''
                )
            )

        if (i + 1) % args.eval_freq == 0:
            print('Start model evaluation at step', i + 1)
            model.eval()
            evaluate(model, val_loader)
            model.train()

        if (i + 1) % args.save_freq == 0:
            checkpoint.save_checkpoint(model, optimizer, lr_sched, loss_scaler, i + 1, args)

        batch_st = datetime.now()


def evaluate(model: torch.nn.Module, loader: torch.utils.data.DataLoader):
    tot, hit1, hit5 = 0, 0, 0
    eval_st = datetime.now()
    for data, labels in loader:
        data, labels = data.cuda(), labels.cuda()
        assert data.size(0) == 1
        if data.ndim == 6:
            data = data[0]  # now the first dimension is number of views

        with torch.no_grad():
            logits = model(data)
            scores = logits.softmax(dim=-1).mean(dim=0)  # average class probabilities over views

        tot += 1
        hit1 += (scores.topk(1)[1] == labels).sum().item()
        hit5 += (scores.topk(5)[1] == labels).sum().item()

        if tot % 20 == 0:
            print(f'[Evaluation] num_samples: {tot}  '
                  f'ETA: {(datetime.now() - eval_st) / tot * (len(loader) - tot)}  '
                  f'cumulative_acc1: {hit1 / tot * 100.:.2f}%  '
                  f'cumulative_acc5: {hit5 / tot * 100.:.2f}%')

    sync_tensor = torch.LongTensor([tot, hit1, hit5]).cuda()
    dist.all_reduce(sync_tensor)
    tot, hit1, hit5 = sync_tensor.cpu().tolist()

    print(f'Accuracy on validation set: top1={hit1 / tot * 100:.2f}%, top5={hit5 / tot * 100:.2f}%')


if __name__ == '__main__':
    main()


================================================
FILE: model.py
================================================
#!/usr/bin/env python

from typing import Dict, Iterable, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vision_transformer import QuickGELU, Attention
from weight_loaders import weight_loader_fn_dict
from vision_transformer import (
    VisionTransformer2D, TransformerDecoderLayer,
    model_to_fp16, vit_presets,
)


class TemporalCrossAttention(nn.Module):

    def __init__(
        self,
        spatial_size: Tuple[int, int] = (14, 14),
        feature_dim: int = 768,
    ):
        super().__init__()

        self.spatial_size = spatial_size

        w_size = np.prod([x * 2 - 1 for x in spatial_size])
        self.w1 = nn.Parameter(torch.zeros([w_size, feature_dim]))
        self.w2 = nn.Parameter(torch.zeros([w_size, feature_dim]))

        # idx_tensor[q, k] indexes the relative 2D offset between query patch q and key patch k.
        idx_tensor = torch.zeros([np.prod(spatial_size) for _ in (0, 1)], dtype=torch.long)
        for q in range(np.prod(spatial_size)):
            qi, qj = q // spatial_size[1], q % spatial_size[1]
            for k in range(np.prod(spatial_size)):
                ki, kj = k // spatial_size[1], k % spatial_size[1]
                i_offs = qi - ki + spatial_size[0] - 1
                j_offs = qj - kj + spatial_size[1] - 1
                idx_tensor[q, k] = i_offs * (spatial_size[1] * 2 - 1) + j_offs
        self.idx_tensor = idx_tensor

    def forward_half(self, q: torch.Tensor, k: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        q, k = q[:, :, 1:], k[:, :, 1:]  # remove cls token

        assert q.size() == k.size()
        assert q.size(2) == np.prod(self.spatial_size)

        attn = torch.einsum('ntqhd,ntkhd->ntqkh', q / (q.size(-1) ** 0.5), k)
        attn = attn.softmax(dim=-2).mean(dim=-1)  # N, T, L, L

        self.idx_tensor = self.idx_tensor.to(w.device)
        w_unroll = w[self.idx_tensor]  # L, L, C
        ret = torch.einsum('ntqk,qkc->ntqc', attn, w_unroll)

        return ret

    def forward(self, q: torch.Tensor, k: torch.Tensor):
        N, T, L, H, D = q.size()
        assert L == np.prod(self.spatial_size) + 1

        # Allocate on the inputs' device instead of hard-coding 'cuda'.
        ret = torch.zeros([N, T, L, self.w1.size(-1)], device=q.device)
        ret[:, 1:, 1:, :] += self.forward_half(q[:, 1:, :, :, :], k[:, :-1, :, :, :], self.w1)
        ret[:, :-1, 1:, :] += self.forward_half(q[:, :-1, :, :, :], k[:, 1:, :, :, :], self.w2)

        return ret


class EVLDecoder(nn.Module):

    def __init__(
        self,
        num_frames: int = 8,
        spatial_size: Tuple[int, int] = (14, 14),
        num_layers: int = 4,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        enable_temporal_conv: bool = True,
        enable_temporal_pos_embed: bool = True,
        enable_temporal_cross_attention: bool = True,
        mlp_dropout: float = 0.5,
    ):
        super().__init__()

        self.enable_temporal_conv = enable_temporal_conv
        self.enable_temporal_pos_embed = enable_temporal_pos_embed
        self.enable_temporal_cross_attention = enable_temporal_cross_attention
        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
            [TransformerDecoderLayer(in_feature_dim, qkv_dim, num_heads, mlp_factor, mlp_dropout) for _ in range(num_layers)]
        )

        if enable_temporal_conv:
            self.temporal_conv = nn.ModuleList(
                [nn.Conv1d(in_feature_dim, in_feature_dim, kernel_size=3, stride=1, padding=1, groups=in_feature_dim) for _ in range(num_layers)]
            )
        if enable_temporal_pos_embed:
            self.temporal_pos_embed = nn.ParameterList(
                [nn.Parameter(torch.zeros([num_frames, in_feature_dim])) for _ in range(num_layers)]
            )
        if enable_temporal_cross_attention:
            self.cross_attention = nn.ModuleList(
                [TemporalCrossAttention(spatial_size, in_feature_dim) for _ in range(num_layers)]
            )

        self.cls_token = nn.Parameter(torch.zeros([in_feature_dim]))

    def _initialize_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)

    def forward(self, in_features: List[Dict[str, torch.Tensor]]):
        N, T, L, C = in_features[0]['out'].size()
        assert len(in_features) == self.num_layers
        x = self.cls_token.view(1, 1, -1).repeat(N, 1, 1)

        for i in range(self.num_layers):
            frame_features = in_features[i]['out']

            # Out-of-place additions: the backbone output may be saved for backward when the
            # backbone is finetuned, so modifying it in place could break autograd.
            if self.enable_temporal_conv:
                feat = in_features[i]['out']
                feat = feat.permute(0, 2, 3, 1).contiguous().flatten(0, 1)  # N * L, C, T
                feat = self.temporal_conv[i](feat)
                feat = feat.view(N, L, C, T).permute(0, 3, 1, 2).contiguous()  # N, T, L, C
                frame_features = frame_features + feat

            if self.enable_temporal_pos_embed:
                frame_features = frame_features + self.temporal_pos_embed[i].view(1, T, 1, C)

            if self.enable_temporal_cross_attention:
                frame_features = frame_features + self.cross_attention[i](in_features[i]['q'], in_features[i]['k'])

            frame_features = frame_features.flatten(1, 2)  # N, T * L, C

            x = self.decoder_layers[i](x, frame_features)

        return x


class EVLTransformer(nn.Module):

    def __init__(
        self,
        num_frames: int = 8,
        backbone_name: str = 'ViT-B/16',
        backbone_type: str = 'clip',
        backbone_path: str = '',
        backbone_mode: str = 'freeze_fp16',  # must be one of 'finetune', 'freeze_fp16', 'freeze_fp32'
        decoder_num_layers: int = 4,
        decoder_qkv_dim: int = 768,
        decoder_num_heads: int = 12,
        decoder_mlp_factor: float = 4.0,
        num_classes: int = 400,
        enable_temporal_conv: bool = True,
        enable_temporal_pos_embed: bool = True,
        enable_temporal_cross_attention: bool = True,
        cls_dropout: float = 0.5,
        decoder_mlp_dropout: float = 0.5,
    ):
        super().__init__()

        self.decoder_num_layers = decoder_num_layers

        backbone_config = self._create_backbone(backbone_name, backbone_type, backbone_path, backbone_mode)
        backbone_feature_dim = backbone_config['feature_dim']
        backbone_spatial_size = tuple(x // y for x, y in zip(backbone_config['input_size'], backbone_config['patch_size']))

        self.decoder = EVLDecoder(
            num_frames=num_frames,
            spatial_size=backbone_spatial_size,
            num_layers=decoder_num_layers,
            in_feature_dim=backbone_feature_dim,
            qkv_dim=decoder_qkv_dim,
            num_heads=decoder_num_heads,
            mlp_factor=decoder_mlp_factor,
            enable_temporal_conv=enable_temporal_conv,
            enable_temporal_pos_embed=enable_temporal_pos_embed,
            enable_temporal_cross_attention=enable_temporal_cross_attention,
            mlp_dropout=decoder_mlp_dropout,
        )
        self.proj = nn.Sequential(
            nn.LayerNorm(backbone_feature_dim),
            nn.Dropout(cls_dropout),
            nn.Linear(backbone_feature_dim, num_classes),
        )

    def _create_backbone(
        self,
        backbone_name: str,
        backbone_type: str,
        backbone_path: str,
        backbone_mode: str,
    ) -> dict:
        weight_loader_fn = weight_loader_fn_dict[backbone_type]
        state_dict = weight_loader_fn(backbone_path)

        backbone = VisionTransformer2D(return_all_features=True, **vit_presets[backbone_name])
        backbone.load_state_dict(state_dict, strict=True)  # weight_loader_fn is expected to strip unused parameters

        assert backbone_mode in ['finetune', 'freeze_fp16', 'freeze_fp32']

        if backbone_mode == 'finetune':
            self.backbone = backbone
        else:
            backbone.eval().requires_grad_(False)
            if backbone_mode == 'freeze_fp16':
                model_to_fp16(backbone)
            self.backbone = [backbone]  # avoid backbone parameter registration

        return vit_presets[backbone_name]

    def _get_backbone(self, x):
        if isinstance(self.backbone, list):
            # freeze backbone
            self.backbone[0] = self.backbone[0].to(x.device)
            return self.backbone[0]
        else:
            # finetune backbone
            return self.backbone

    def forward(self, x: torch.Tensor):
        backbone = self._get_backbone(x)

        B, C, T, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)  # B * T, C, H, W
        features = backbone(x)[-self.decoder_num_layers:]
        features = [
            dict((k, v.float().view(B, T, *v.size()[1:])) for k, v in f.items())
            for f in features
        ]

        x = self.decoder(features)
        x = self.proj(x[:, 0, :])

        return x


================================================
FILE: scripts/eval_k400_vitb16_16f_dec4x768.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
  --num_steps 50000 \
  --backbone "ViT-B/16-lnpre" \
  --backbone_type clip \
  --backbone_path /path/to/clip_models/ViT-B-16.pt \
  --decoder_num_layers 4 \
  --decoder_qkv_dim 768 \
  --decoder_num_heads 12 \
  --num_classes 400 \
  --val_list_path /path/to/k400/val.txt \
  --batch_size 256 \
  --batch_split 1 \
  --auto_augment rand-m7-n4-mstd0.5-inc1 \
  --mean 0.48145466 0.4578275 0.40821073 \
  --std 0.26862954 0.26130258 0.27577711 \
  --num_workers 12 \
  --num_frames 16 \
  --sampling_rate 16 \
  --num_spatial_views 3 \
  --num_temporal_views 1 \
  --resume_path /path/to/checkpoint_release/k400_vitb16_16f_dec4x768.pth \
  --eval_only


================================================
FILE: scripts/eval_k400_vitb16_32f_dec4x768.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
  --num_steps 50000 \
  --backbone "ViT-B/16-lnpre" \
  --backbone_type clip \
  --backbone_path /path/to/clip_models/ViT-B-16.pt \
  --decoder_num_layers 4 \
  --decoder_qkv_dim 768 \
  --decoder_num_heads 12 \
  --num_classes 400 \
  --val_list_path /path/to/k400/val.txt \
  --batch_size 256 \
  --batch_split 1 \
  --auto_augment rand-m7-n4-mstd0.5-inc1 \
  --mean 0.48145466 0.4578275 0.40821073 \
  --std 0.26862954 0.26130258 0.27577711 \
  --num_workers 12 \
  --num_frames 32 \
  --sampling_rate 8 \
  --num_spatial_views 3 \
  --num_temporal_views 1 \
  --resume_path /path/to/checkpoint_release/k400_vitb16_32f_dec4x768.pth \
  --eval_only


================================================
FILE: scripts/eval_k400_vitb16_8f_dec4x768.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
  --num_steps 50000 \
  --backbone "ViT-B/16-lnpre" \
  --backbone_type clip \
  --backbone_path /path/to/clip_models/ViT-B-16.pt \
  --decoder_num_layers 4 \
  --decoder_qkv_dim 768 \
  --decoder_num_heads 12 \
  --num_classes 400 \
  --val_list_path /path/to/k400/val.txt \
  --batch_size 256 \
  --batch_split 1 \
  --auto_augment rand-m7-n4-mstd0.5-inc1 \
  --mean 0.48145466 0.4578275 0.40821073 \
  --std 0.26862954 0.26130258 0.27577711 \
  --num_workers 12 \
  --num_frames 8 \
  --sampling_rate 16 \
  --num_spatial_views 1 \
  --num_temporal_views 3 \
  --resume_path /path/to/checkpoint_release/k400_vitb16_8f_dec4x768.pth \
  --eval_only


================================================
FILE: scripts/eval_k400_vitl14_16f_dec4x1024.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
  --num_steps 50000 \
  --backbone "ViT-L/14-lnpre" \
  --backbone_type clip \
  --backbone_path /path/to/clip_models/ViT-L-14.pt \
  --decoder_num_layers 4 \
  --decoder_qkv_dim 1024 \
  --decoder_num_heads 16 \
  --num_classes 400 \
  --val_list_path /path/to/k400/val.txt \
  --batch_size 256 \
  --batch_split 1 \
  --auto_augment rand-m7-n4-mstd0.5-inc1 \
  --mean 0.48145466 0.4578275 0.40821073 \
  --std 0.26862954 0.26130258 0.27577711 \
  --num_workers 12 \
  --num_frames 16 \
  --sampling_rate 16 \
  --num_spatial_views 3 \
  --num_temporal_views 1 \
  --resume_path /path/to/checkpoint_release/k400_vitl14_16f_dec4x1024.pth \
  --eval_only


================================================
FILE: scripts/eval_k400_vitl14_32f_dec4x1024.sh
================================================
#!/usr/bin/env sh

python -u -m torch.distributed.run --nproc_per_node 4 \
  main.py \
  --num_steps 50000 \
  --backbone "ViT-L/14-lnpre" \
  --backbone_type clip \
  --backbone_path /path/to/clip_models/ViT-L-14.pt \
  --decoder_num_layers 4
\ --decoder_qkv_dim 1024 \ --decoder_num_heads 16 \ --num_classes 400 \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 1 \ --auto_augment rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ --std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ --num_frames 32 \ --sampling_rate 8 \ --num_spatial_views 3 \ --num_temporal_views 1 \ --resume_path /path/to/checkpoint_release/k400_vitl14_32f_dec4x1024.pth \ --eval_only ================================================ FILE: scripts/eval_k400_vitl14_8f_dec4x1024.sh ================================================ #!/usr/bin/env sh python -u -m torch.distributed.run --nproc_per_node 4 \ main.py \ --num_steps 50000 \ --backbone "ViT-L/14-lnpre" \ --backbone_type clip \ --backbone_path /path/to/clip_models/ViT-L-14.pt \ --decoder_num_layers 4 \ --decoder_qkv_dim 1024 \ --decoder_num_heads 16 \ --num_classes 400 \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 1 \ --auto_augment rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ --std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ --num_frames 8 \ --sampling_rate 16 \ --num_spatial_views 1 \ --num_temporal_views 3 \ --resume_path /path/to/checkpoint_release/k400_vitl14_8f_dec4x1024.pth \ --eval_only ================================================ FILE: scripts/train_k400_vitb16_16f_dec4x768.sh ================================================ #!/usr/bin/env sh exp_dir=runs/k400_vitb16_16f_dec4x768 mkdir -p "${exp_dir}" python -u -m torch.distributed.run --nproc_per_node 8 \ main.py \ --num_steps 50000 \ --backbone "ViT-B/16-lnpre" \ --backbone_type clip \ --backbone_path /path/to/clip_models/ViT-B-16.pt \ --decoder_num_layers 4 \ --decoder_qkv_dim 768 \ --decoder_num_heads 12 \ --num_classes 400 \ --checkpoint_dir "${exp_dir}" \ --auto_resume \ --train_list_path /path/to/k400/train.txt \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 1 \ --auto_augment 
rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ --std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ --num_frames 16 \ --sampling_rate 16 \ --num_spatial_views 3 \ --num_temporal_views 1 \ 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" ================================================ FILE: scripts/train_k400_vitb16_32f_dec4x768.sh ================================================ #!/usr/bin/env sh exp_dir=runs/k400_vitb16_32f_dec4x768 mkdir -p "${exp_dir}" python -u -m torch.distributed.run --nproc_per_node 8 \ main.py \ --num_steps 50000 \ --backbone "ViT-B/16-lnpre" \ --backbone_type clip \ --backbone_path /path/to/clip_models/ViT-B-16.pt \ --decoder_num_layers 4 \ --decoder_qkv_dim 768 \ --decoder_num_heads 12 \ --num_classes 400 \ --checkpoint_dir "${exp_dir}" \ --auto_resume \ --train_list_path /path/to/k400/train.txt \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 1 \ --auto_augment rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ --std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ --num_frames 32 \ --sampling_rate 8 \ --num_spatial_views 3 \ --num_temporal_views 1 \ 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" ================================================ FILE: scripts/train_k400_vitb16_8f_dec4x768.sh ================================================ #!/usr/bin/env sh exp_dir=runs/k400_vitb16_8f_dec4x768 mkdir -p "${exp_dir}" python -u -m torch.distributed.run --nproc_per_node 8 \ main.py \ --num_steps 50000 \ --backbone "ViT-B/16-lnpre" \ --backbone_type clip \ --backbone_path /path/to/clip_models/ViT-B-16.pt \ --decoder_num_layers 4 \ --decoder_qkv_dim 768 \ --decoder_num_heads 12 \ --num_classes 400 \ --checkpoint_dir "${exp_dir}" \ --auto_resume \ --train_list_path /path/to/k400/train.txt \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 1 \ --auto_augment rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ 
--std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ --num_frames 8 \ --sampling_rate 16 \ --num_spatial_views 1 \ --num_temporal_views 3 \ 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" ================================================ FILE: scripts/train_k400_vitl14_16f_dec4x1024.sh ================================================ #!/usr/bin/env sh exp_dir=runs/k400_vitl14_16f_dec4x1024 mkdir -p "${exp_dir}" python -u -m torch.distributed.run --nproc_per_node 8 \ main.py \ --num_steps 50000 \ --backbone "ViT-L/14-lnpre" \ --backbone_type clip \ --backbone_path /path/to/clip_models/ViT-L-14.pt \ --decoder_num_layers 4 \ --decoder_qkv_dim 1024 \ --decoder_num_heads 16 \ --num_classes 400 \ --checkpoint_dir "${exp_dir}" \ --auto_resume \ --train_list_path /path/to/k400/train.txt \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 2 \ --auto_augment rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ --std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ --num_frames 16 \ --sampling_rate 16 \ --num_spatial_views 3 \ --num_temporal_views 1 \ 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" ================================================ FILE: scripts/train_k400_vitl14_32f_dec4x1024.sh ================================================ #!/usr/bin/env sh exp_dir=runs/k400_vitl14_32f_dec4x1024 mkdir -p "${exp_dir}" python -u -m torch.distributed.run --nproc_per_node 8 \ main.py \ --num_steps 50000 \ --backbone "ViT-L/14-lnpre" \ --backbone_type clip \ --backbone_path /path/to/clip_models/ViT-L-14.pt \ --decoder_num_layers 4 \ --decoder_qkv_dim 1024 \ --decoder_num_heads 16 \ --num_classes 400 \ --checkpoint_dir "${exp_dir}" \ --auto_resume \ --train_list_path /path/to/k400/train.txt \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 4 \ --auto_augment rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ --std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ 
--num_frames 32 \ --sampling_rate 8 \ --num_spatial_views 3 \ --num_temporal_views 1 \ 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" ================================================ FILE: scripts/train_k400_vitl14_8f_dec4x1024.sh ================================================ #!/usr/bin/env sh exp_dir=runs/k400_vitl14_8f_dec4x1024 mkdir -p "${exp_dir}" python -u -m torch.distributed.run --nproc_per_node 8 \ main.py \ --num_steps 50000 \ --backbone "ViT-L/14-lnpre" \ --backbone_type clip \ --backbone_path /path/to/clip_models/ViT-L-14.pt \ --decoder_num_layers 4 \ --decoder_qkv_dim 1024 \ --decoder_num_heads 16 \ --num_classes 400 \ --checkpoint_dir "${exp_dir}" \ --auto_resume \ --train_list_path /path/to/k400/train.txt \ --val_list_path /path/to/k400/val.txt \ --batch_size 256 \ --batch_split 1 \ --auto_augment rand-m7-n4-mstd0.5-inc1 \ --mean 0.48145466 0.4578275 0.40821073 \ --std 0.26862954 0.26130258 0.27577711 \ --num_workers 12 \ --num_frames 8 \ --sampling_rate 16 \ --num_spatial_views 1 \ --num_temporal_views 3 \ 2>&1 | tee "${exp_dir}/train-$(date +"%Y%m%d_%H%M%S").log" ================================================ FILE: video_dataset/__init__.py ================================================ #!/usr/bin/env python from .dataloader import setup_arg_parser, create_train_loader, create_val_loader ================================================ FILE: video_dataset/dataloader.py ================================================ #!/usr/bin/env python import argparse from typing import Dict import torch import torch.distributed as dist from .dataset import VideoDataset, DummyDataset def setup_arg_parser(parser: argparse.ArgumentParser): parser.add_argument('--train_list_path', type=str, help='path to training data list') parser.add_argument('--val_list_path', type=str, help='path to validation data list') parser.add_argument('--train_data_root', type=str, help='training samples root directory') parser.add_argument('--val_data_root', type=str, 
help='validation samples root directory') parser.add_argument('--data_root', type=str, default='', help='training and validation samples root directory, may be overridden by --train_data_root or --val_data_root') parser.add_argument('--batch_size', type=int, help='total training batch size across all GPUs') parser.add_argument('--num_spatial_views', type=int, default=1, help='number of spatial crops used for testing (total views = num_spatial_views * num_temporal_views)') parser.add_argument('--num_temporal_views', type=int, default=3, help='number of temporal crops used for testing (total views = num_spatial_views * num_temporal_views)') parser.add_argument('--num_frames', type=int, default=8, help='number of frames used for each view') parser.add_argument('--sampling_rate', type=int, default=16, help='temporal stride for frame sampling, only valid when tsn_sampling is not enabled') parser.add_argument('--tsn_sampling', action='store_true', help='enable TSN-style sampling (i.e. sample frames with dynamic stride to cover the whole video)') parser.add_argument('--spatial_size', type=int, default=224, help='frame height and width in pixels') parser.add_argument('--mean', type=float, nargs='+', help='pixel mean used to normalize the image') parser.add_argument('--std', type=float, nargs='+', help='pixel std used to normalize the image') parser.add_argument('--num_workers', type=int, default=10, help='number of DataLoader worker processes') parser.add_argument('--dummy_dataset', action='store_true', help='use fake datasets that generate all-zero inputs (for speed tests only)') parser.add_argument('--auto_augment', type=str, help='auto augment configuration') parser.add_argument('--interpolation', type=str, default='bicubic', help='interpolation mode') parser.add_argument('--no_mirror', action='store_false', dest='mirror', help='disable mirror for training (frequently used for the something-something dataset)') parser.set_defaults(mirror=True) def _parse_mean_and_std(args:
argparse.Namespace) -> Dict[str, torch.Tensor]: def parse_mean_or_std(arg, default_value): if arg is None: return torch.Tensor([default_value] * 3) elif len(arg) == 1: return torch.Tensor(arg * 3) elif len(arg) == 3: return torch.Tensor(arg) else: raise NotImplementedError() return { 'mean': parse_mean_or_std(args.mean, 0.45), 'std': parse_mean_or_std(args.std, 0.225), } def create_train_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset: if args.dummy_dataset: return DummyDataset( list_path=args.train_list_path, num_frames=args.num_frames, num_views=1, spatial_size=args.spatial_size, ) return VideoDataset( list_path=args.train_list_path, data_root=args.train_data_root or args.data_root, num_spatial_views=1, num_temporal_views=1, random_sample=True, auto_augment=args.auto_augment, interpolation=args.interpolation, mirror=args.mirror, num_frames=args.num_frames, sampling_rate=-1 if args.tsn_sampling else args.sampling_rate, spatial_size=args.spatial_size, **_parse_mean_and_std(args), ) def create_train_loader(args: argparse.Namespace, resume_step: int = 0) -> torch.utils.data.DataLoader: dataset = create_train_dataset(args) rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size()) assert args.batch_size % world_size == 0 batch_size_per_gpu = args.batch_size // world_size # manually create a step-based sampler sampler = [] while len(sampler) * len(dataset) < args.num_steps * args.batch_size: g = torch.Generator() g.manual_seed(len(sampler)) indices = torch.randperm(len(dataset), generator=g) sampler.append(indices) sampler = torch.cat(sampler, dim=0)[:args.num_steps * args.batch_size].view(args.num_steps, args.batch_size) sampler = sampler[resume_step:, batch_size_per_gpu * rank: batch_size_per_gpu * (rank + 1)].flatten().tolist() loader = torch.utils.data.DataLoader( dataset, sampler=sampler, batch_size=batch_size_per_gpu, num_workers=args.num_workers, pin_memory=False, drop_last=True, ) return loader def 
create_val_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset: if args.dummy_dataset: return DummyDataset( list_path=args.val_list_path, num_frames=args.num_frames, num_views=args.num_spatial_views * args.num_temporal_views, spatial_size=args.spatial_size, ) return VideoDataset( list_path=args.val_list_path, data_root=args.val_data_root or args.data_root, num_spatial_views=args.num_spatial_views, num_temporal_views=args.num_temporal_views, random_sample=False, num_frames=args.num_frames, sampling_rate=-1 if args.tsn_sampling else args.sampling_rate, spatial_size=args.spatial_size, **_parse_mean_and_std(args), ) def create_val_loader(args: argparse.Namespace) -> torch.utils.data.DataLoader: dataset = create_val_dataset(args) rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size()) # sampler for distributed eval sampler = list(range(rank, len(dataset), world_size)) loader = torch.utils.data.DataLoader( dataset, sampler=sampler, batch_size=1, num_workers=args.num_workers, pin_memory=False, ) return loader ================================================ FILE: video_dataset/dataset.py ================================================ #!/usr/bin/env python import os, sys from typing import Optional import av import io import numpy as np import torch from torchvision import transforms from .transform import create_random_augment, random_resized_crop class VideoDataset(torch.utils.data.Dataset): def __init__( self, list_path: str, data_root: str, num_spatial_views: int, num_temporal_views: int, random_sample: bool, num_frames: int, sampling_rate: int, spatial_size: int, mean: torch.Tensor, std: torch.Tensor, auto_augment: Optional[str] = None, interpolation: str = 'bicubic', mirror: bool = False, ): self.data_root = data_root self.interpolation = interpolation self.spatial_size = spatial_size self.mean, self.std = mean, std self.num_frames, self.sampling_rate = num_frames, sampling_rate if random_sample: assert
num_spatial_views == 1 and num_temporal_views == 1 self.random_sample = True self.mirror = mirror self.auto_augment = auto_augment else: assert auto_augment is None and not mirror self.random_sample = False self.num_temporal_views = num_temporal_views self.num_spatial_views = num_spatial_views with open(list_path) as f: self.data_list = f.read().splitlines() def __len__(self): return len(self.data_list) def __getitem__(self, idx): line = self.data_list[idx] path, label = line.split(' ') path = os.path.join(self.data_root, path) label = int(label) container = av.open(path) frames = {} for frame in container.decode(video=0): frames[frame.pts] = frame container.close() frames = [frames[k] for k in sorted(frames.keys())] if self.random_sample: frame_idx = self._random_sample_frame_idx(len(frames)) frames = [frames[x].to_rgb().to_ndarray() for x in frame_idx] frames = torch.as_tensor(np.stack(frames)).float() / 255. if self.auto_augment is not None: aug_transform = create_random_augment( input_size=(frames.size(1), frames.size(2)), auto_augment=self.auto_augment, interpolation=self.interpolation, ) frames = frames.permute(0, 3, 1, 2) # T, C, H, W frames = [transforms.ToPILImage()(frames[i]) for i in range(frames.size(0))] frames = aug_transform(frames) frames = torch.stack([transforms.ToTensor()(img) for img in frames]) frames = frames.permute(0, 2, 3, 1) frames = (frames - self.mean) / self.std frames = frames.permute(3, 0, 1, 2) # C, T, H, W frames = random_resized_crop( frames, self.spatial_size, self.spatial_size, ) else: frames = [x.to_rgb().to_ndarray() for x in frames] frames = torch.as_tensor(np.stack(frames)) frames = frames.float() / 255. 
frames = (frames - self.mean) / self.std frames = frames.permute(3, 0, 1, 2) # C, T, H, W if frames.size(-2) < frames.size(-1): new_width = frames.size(-1) * self.spatial_size // frames.size(-2) new_height = self.spatial_size else: new_height = frames.size(-2) * self.spatial_size // frames.size(-1) new_width = self.spatial_size frames = torch.nn.functional.interpolate( frames, size=(new_height, new_width), mode='bilinear', align_corners=False, ) frames = self._generate_spatial_crops(frames) frames = sum([self._generate_temporal_crops(x) for x in frames], []) if len(frames) > 1: frames = torch.stack(frames) return frames, label def _generate_temporal_crops(self, frames): seg_len = (self.num_frames - 1) * self.sampling_rate + 1 if frames.size(1) < seg_len: frames = torch.cat([frames, frames[:, -1:].repeat(1, seg_len - frames.size(1), 1, 1)], dim=1) slide_len = frames.size(1) - seg_len crops = [] for i in range(self.num_temporal_views): if self.num_temporal_views == 1: st = slide_len // 2 else: st = round(slide_len / (self.num_temporal_views - 1) * i) crops.append(frames[:, st: st + self.num_frames * self.sampling_rate: self.sampling_rate]) return crops def _generate_spatial_crops(self, frames): if self.num_spatial_views == 1: assert min(frames.size(-2), frames.size(-1)) >= self.spatial_size h_st = (frames.size(-2) - self.spatial_size) // 2 w_st = (frames.size(-1) - self.spatial_size) // 2 h_ed, w_ed = h_st + self.spatial_size, w_st + self.spatial_size return [frames[:, :, h_st: h_ed, w_st: w_ed]] elif self.num_spatial_views == 3: assert min(frames.size(-2), frames.size(-1)) == self.spatial_size crops = [] margin = max(frames.size(-2), frames.size(-1)) - self.spatial_size for st in (0, margin // 2, margin): ed = st + self.spatial_size if frames.size(-2) > frames.size(-1): crops.append(frames[:, :, st: ed, :]) else: crops.append(frames[:, :, :, st: ed]) return crops else: raise NotImplementedError() def _random_sample_frame_idx(self, len): frame_indices = [] if 
self.sampling_rate < 0: # tsn sample seg_size = (len - 1) / self.num_frames for i in range(self.num_frames): start, end = round(seg_size * i), round(seg_size * (i + 1)) frame_indices.append(np.random.randint(start, end + 1)) elif self.sampling_rate * (self.num_frames - 1) + 1 >= len: for i in range(self.num_frames): frame_indices.append(i * self.sampling_rate if i * self.sampling_rate < len else frame_indices[-1]) else: start = np.random.randint(len - self.sampling_rate * (self.num_frames - 1)) frame_indices = list(range(start, start + self.sampling_rate * self.num_frames, self.sampling_rate)) return frame_indices class DummyDataset(torch.utils.data.Dataset): def __init__(self, list_path: str, num_frames: int, num_views: int, spatial_size: int): with open(list_path) as f: self.len = len(f.read().splitlines()) self.num_frames = num_frames self.num_views = num_views self.spatial_size = spatial_size def __len__(self): return self.len def __getitem__(self, _): shape = [3, self.num_frames, self.spatial_size, self.spatial_size] if self.num_views != 1: shape = [self.num_views] + shape return torch.zeros(shape), 0 ================================================ FILE: video_dataset/rand_augment.py ================================================ #!/usr/bin/env python # Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/rand_augment.py # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. """ This implementation is based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py published under an Apache License 2.0. COMMENT FROM ORIGINAL: AutoAugment, RandAugment, and AugMix for PyTorch This code implements the searched ImageNet policies with various tweaks and improvements and does not include any of the search code.
AA and RA Implementation adapted from: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py AugMix adapted from: https://github.com/google-research/augmix Papers: AutoAugment: Learning Augmentation Policies from Data https://arxiv.org/abs/1805.09501 Learning Data Augmentation Strategies for Object Detection https://arxiv.org/abs/1906.11172 RandAugment: Practical automated data augmentation... https://arxiv.org/abs/1909.13719 AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty https://arxiv.org/abs/1912.02781 Hacked together by / Copyright 2020 Ross Wightman """ import math import numpy as np import random import re import PIL from PIL import Image, ImageEnhance, ImageOps _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) _FILL = (128, 128, 128) # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. _MAX_LEVEL = 10.0 _HPARAMS_DEFAULT = { "translate_const": 250, "img_mean": _FILL, } _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) def _interpolation(kwargs): interpolation = kwargs.pop("resample", Image.BILINEAR) if isinstance(interpolation, (list, tuple)): return random.choice(interpolation) else: return interpolation def _check_args_tf(kwargs): if "fillcolor" in kwargs and _PIL_VER < (5, 0): kwargs.pop("fillcolor") kwargs["resample"] = _interpolation(kwargs) def shear_x(img, factor, **kwargs): _check_args_tf(kwargs) return img.transform( img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs ) def shear_y(img, factor, **kwargs): _check_args_tf(kwargs) return img.transform( img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs ) def translate_x_rel(img, pct, **kwargs): pixels = pct * img.size[0] _check_args_tf(kwargs) return img.transform( img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs ) def translate_y_rel(img, pct, **kwargs): pixels = pct * img.size[1] _check_args_tf(kwargs) return img.transform( img.size, 
Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs ) def translate_x_abs(img, pixels, **kwargs): _check_args_tf(kwargs) return img.transform( img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs ) def translate_y_abs(img, pixels, **kwargs): _check_args_tf(kwargs) return img.transform( img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs ) def rotate(img, degrees, **kwargs): _check_args_tf(kwargs) if _PIL_VER >= (5, 2): return img.rotate(degrees, **kwargs) elif _PIL_VER >= (5, 0): w, h = img.size post_trans = (0, 0) rotn_center = (w / 2.0, h / 2.0) angle = -math.radians(degrees) matrix = [ round(math.cos(angle), 15), round(math.sin(angle), 15), 0.0, round(-math.sin(angle), 15), round(math.cos(angle), 15), 0.0, ] def transform(x, y, matrix): (a, b, c, d, e, f) = matrix return a * x + b * y + c, d * x + e * y + f matrix[2], matrix[5] = transform( -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix, ) matrix[2] += rotn_center[0] matrix[5] += rotn_center[1] return img.transform(img.size, Image.AFFINE, matrix, **kwargs) else: return img.rotate(degrees, resample=kwargs["resample"]) def auto_contrast(img, **__): return ImageOps.autocontrast(img) def invert(img, **__): return ImageOps.invert(img) def equalize(img, **__): return ImageOps.equalize(img) def solarize(img, thresh, **__): return ImageOps.solarize(img, thresh) def solarize_add(img, add, thresh=128, **__): lut = [] for i in range(256): if i < thresh: lut.append(min(255, i + add)) else: lut.append(i) if img.mode in ("L", "RGB"): if img.mode == "RGB" and len(lut) == 256: lut = lut + lut + lut return img.point(lut) else: return img def posterize(img, bits_to_keep, **__): if bits_to_keep >= 8: return img return ImageOps.posterize(img, bits_to_keep) def contrast(img, factor, **__): return ImageEnhance.Contrast(img).enhance(factor) def color(img, factor, **__): return ImageEnhance.Color(img).enhance(factor) def brightness(img, factor, **__): return ImageEnhance.Brightness(img).enhance(factor) 
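# --- Illustrative sketch: hypothetical helpers, NOT part of the SlowFast/timm code ---
# A minimal, self-contained restatement of how an AugmentOp (defined further down)
# turns a RandAugment magnitude into a concrete op argument for a whole clip.
# Assumption: the scripts' "rand-m7-n4-mstd0.5-inc1" string is parsed as in timm,
# i.e. magnitude 7 and magnitude std 0.5; the config parser is not reproduced here.
# AugmentOp samples the magnitude once per call and reuses the resulting arguments
# for every frame in the list, so the clip is augmented consistently over time
# (it also skips the op for the whole clip with probability 1 - prob, omitted here).
import random


def _example_level_args(op_kind, magnitude=7.0, magnitude_std=0.5, max_level=10.0, rng=None):
    """Sample one argument for a 'rotate' or 'enhance_increasing' op (hypothetical)."""
    rng = rng or random.Random()
    m = magnitude
    if magnitude_std > 0:
        m = rng.gauss(m, magnitude_std)  # per-call magnitude jitter
    m = min(max_level, max(0.0, m))  # clip to the valid level range
    if op_kind == "rotate":
        # Same arithmetic as _rotate_level_to_arg: [0, 30] degrees, sign flipped 50% of the time.
        v = (m / max_level) * 30.0
        return -v if rng.random() > 0.5 else v
    if op_kind == "enhance_increasing":
        # Same arithmetic as _enhance_increasing_level_to_arg: 1.0 +/- up to 0.9, i.e. [0.1, 1.9].
        v = (m / max_level) * 0.9
        return 1.0 + (-v if rng.random() > 0.5 else v)
    raise ValueError(f"unsupported op_kind: {op_kind!r}")


def _example_apply_to_clip(frames, op, arg):
    """Apply one op with one shared argument to every frame (hypothetical helper)."""
    return [op(frame, arg) for frame in frames]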
def sharpness(img, factor, **__): return ImageEnhance.Sharpness(img).enhance(factor) def _randomly_negate(v): """With 50% prob, negate the value""" return -v if random.random() > 0.5 else v def _rotate_level_to_arg(level, _hparams): # range [-30, 30] level = (level / _MAX_LEVEL) * 30.0 level = _randomly_negate(level) return (level,) def _enhance_level_to_arg(level, _hparams): # range [0.1, 1.9] return ((level / _MAX_LEVEL) * 1.8 + 0.1,) def _enhance_increasing_level_to_arg(level, _hparams): # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend # range [0.1, 1.9] level = (level / _MAX_LEVEL) * 0.9 level = 1.0 + _randomly_negate(level) return (level,) def _shear_level_to_arg(level, _hparams): # range [-0.3, 0.3] level = (level / _MAX_LEVEL) * 0.3 level = _randomly_negate(level) return (level,) def _translate_abs_level_to_arg(level, hparams): translate_const = hparams["translate_const"] level = (level / _MAX_LEVEL) * float(translate_const) level = _randomly_negate(level) return (level,) def _translate_rel_level_to_arg(level, hparams): # default range [-0.45, 0.45] translate_pct = hparams.get("translate_pct", 0.45) level = (level / _MAX_LEVEL) * translate_pct level = _randomly_negate(level) return (level,) def _posterize_level_to_arg(level, _hparams): # As per Tensorflow TPU EfficientNet impl # range [0, 4], 'keep 0 up to 4 MSB of original image' # intensity/severity of augmentation decreases with level return (int((level / _MAX_LEVEL) * 4),) def _posterize_increasing_level_to_arg(level, hparams): # As per Tensorflow models research and UDA impl # range [4, 0], 'keep 4 down to 0 MSB of original image', # intensity/severity of augmentation increases with level return (4 - _posterize_level_to_arg(level, hparams)[0],) def _posterize_original_level_to_arg(level, _hparams): # As per original AutoAugment paper description # range [4, 8], 'keep 4 up to 8 MSB of image' # intensity/severity of augmentation decreases with level 
return (int((level / _MAX_LEVEL) * 4) + 4,) def _solarize_level_to_arg(level, _hparams): # range [0, 256] # intensity/severity of augmentation decreases with level return (int((level / _MAX_LEVEL) * 256),) def _solarize_increasing_level_to_arg(level, _hparams): # range [0, 256] # intensity/severity of augmentation increases with level return (256 - _solarize_level_to_arg(level, _hparams)[0],) def _solarize_add_level_to_arg(level, _hparams): # range [0, 110] return (int((level / _MAX_LEVEL) * 110),) LEVEL_TO_ARG = { "AutoContrast": None, "Equalize": None, "Invert": None, "Rotate": _rotate_level_to_arg, # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers "Posterize": _posterize_level_to_arg, "PosterizeIncreasing": _posterize_increasing_level_to_arg, "PosterizeOriginal": _posterize_original_level_to_arg, "Solarize": _solarize_level_to_arg, "SolarizeIncreasing": _solarize_increasing_level_to_arg, "SolarizeAdd": _solarize_add_level_to_arg, "Color": _enhance_level_to_arg, "ColorIncreasing": _enhance_increasing_level_to_arg, "Contrast": _enhance_level_to_arg, "ContrastIncreasing": _enhance_increasing_level_to_arg, "Brightness": _enhance_level_to_arg, "BrightnessIncreasing": _enhance_increasing_level_to_arg, "Sharpness": _enhance_level_to_arg, "SharpnessIncreasing": _enhance_increasing_level_to_arg, "ShearX": _shear_level_to_arg, "ShearY": _shear_level_to_arg, "TranslateX": _translate_abs_level_to_arg, "TranslateY": _translate_abs_level_to_arg, "TranslateXRel": _translate_rel_level_to_arg, "TranslateYRel": _translate_rel_level_to_arg, } NAME_TO_OP = { "AutoContrast": auto_contrast, "Equalize": equalize, "Invert": invert, "Rotate": rotate, "Posterize": posterize, "PosterizeIncreasing": posterize, "PosterizeOriginal": posterize, "Solarize": solarize, "SolarizeIncreasing": solarize, "SolarizeAdd": solarize_add, "Color": color, "ColorIncreasing": color, "Contrast": contrast, "ContrastIncreasing": contrast, 
"Brightness": brightness, "BrightnessIncreasing": brightness, "Sharpness": sharpness, "SharpnessIncreasing": sharpness, "ShearX": shear_x, "ShearY": shear_y, "TranslateX": translate_x_abs, "TranslateY": translate_y_abs, "TranslateXRel": translate_x_rel, "TranslateYRel": translate_y_rel, } class AugmentOp: """ Apply the same op, with a magnitude sampled once per call, to every frame of a video clip. """ def __init__(self, name, prob=0.5, magnitude=10, hparams=None): hparams = hparams or _HPARAMS_DEFAULT self.aug_fn = NAME_TO_OP[name] self.level_fn = LEVEL_TO_ARG[name] self.prob = prob self.magnitude = magnitude self.hparams = hparams.copy() self.kwargs = { "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL, "resample": hparams["interpolation"] if "interpolation" in hparams else _RANDOM_INTERPOLATION, } # If magnitude_std is > 0, we introduce some randomness # in the usually fixed policy and sample magnitude from a normal distribution # with mean `magnitude` and std-dev of `magnitude_std`. # NOTE This is my own hack, being tested, not in papers or reference impls.
        self.magnitude_std = self.hparams.get("magnitude_std", 0)

    def __call__(self, img_list):
        if self.prob < 1.0 and random.random() > self.prob:
            return img_list
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))  # clip to valid range
        level_args = (
            self.level_fn(magnitude, self.hparams)
            if self.level_fn is not None
            else ()
        )

        if isinstance(img_list, list):
            return [
                self.aug_fn(img, *level_args, **self.kwargs) for img in img_list
            ]
        else:
            return self.aug_fn(img_list, *level_args, **self.kwargs)


_RAND_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "Posterize",
    "Solarize",
    "SolarizeAdd",
    "Color",
    "Contrast",
    "Brightness",
    "Sharpness",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
]


_RAND_INCREASING_TRANSFORMS = [
    "AutoContrast",
    "Equalize",
    "Invert",
    "Rotate",
    "PosterizeIncreasing",
    "SolarizeIncreasing",
    "SolarizeAdd",
    "ColorIncreasing",
    "ContrastIncreasing",
    "BrightnessIncreasing",
    "SharpnessIncreasing",
    "ShearX",
    "ShearY",
    "TranslateXRel",
    "TranslateYRel",
]


# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to do so.
_RAND_CHOICE_WEIGHTS_0 = {
    "Rotate": 0.3,
    "ShearX": 0.2,
    "ShearY": 0.2,
    "TranslateXRel": 0.1,
    "TranslateYRel": 0.1,
    "Color": 0.025,
    "Sharpness": 0.025,
    "AutoContrast": 0.025,
    "Solarize": 0.005,
    "SolarizeAdd": 0.005,
    "Contrast": 0.005,
    "Brightness": 0.005,
    "Equalize": 0.005,
    "Posterize": 0,
    "Invert": 0,
}


def _select_rand_weights(weight_idx=0, transforms=None):
    transforms = transforms or _RAND_TRANSFORMS
    assert weight_idx == 0  # only one set of weights currently
    rand_weights = _RAND_CHOICE_WEIGHTS_0
    probs = [rand_weights[k] for k in transforms]
    probs /= np.sum(probs)
    return probs


def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
    hparams = hparams or _HPARAMS_DEFAULT
    transforms = transforms or _RAND_TRANSFORMS
    return [
        AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
        for name in transforms
    ]


class RandAugment:
    def __init__(self, ops, num_layers=2, choice_weights=None):
        self.ops = ops
        self.num_layers = num_layers
        self.choice_weights = choice_weights

    def __call__(self, img):
        # no replacement when using weighted choice
        ops = np.random.choice(
            self.ops,
            self.num_layers,
            replace=self.choice_weights is None,
            p=self.choice_weights,
        )
        for op in ops:
            img = op(img)
        return img


def rand_augment_transform(config_str, hparams):
    """
    RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719

    Create a RandAugment transform
    :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
    dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
    The remaining sections, not order specific, determine
        'm' - integer magnitude of rand augment
        'n' - integer num layers (number of transform ops selected per image)
        'w' - integer probability weight index (index of a set of weights to influence choice of op)
        'mstd' - float std deviation of magnitude noise applied
        'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
    Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
       'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
    :return: A PyTorch compatible Transform
    """
    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
    num_layers = 2  # default to 2 ops per image
    weight_idx = None  # default to no probability weights for op choice
    transforms = _RAND_TRANSFORMS
    config = config_str.split("-")
    assert config[0] == "rand"
    config = config[1:]
    for c in config:
        cs = re.split(r"(\d.*)", c)
        if len(cs) < 2:
            continue
        key, val = cs[:2]
        if key == "mstd":
            # noise param injected via hparams for now
            hparams.setdefault("magnitude_std", float(val))
        elif key == "inc":
            # int() first so that "inc0" disables the increasing variants (bool("0") is True)
            if bool(int(val)):
                transforms = _RAND_INCREASING_TRANSFORMS
        elif key == "m":
            magnitude = int(val)
        elif key == "n":
            num_layers = int(val)
        elif key == "w":
            weight_idx = int(val)
        else:
            # unknown keys are an error rather than being silently ignored
            raise NotImplementedError(f"Unknown RandAugment config key: {key}")
    ra_ops = rand_augment_ops(
        magnitude=magnitude, hparams=hparams, transforms=transforms
    )
    choice_weights = (
        None if weight_idx is None else _select_rand_weights(weight_idx)
    )
    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)


================================================
FILE: video_dataset/random_erasing.py
================================================
#!/usr/bin/env python
# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/random_erasing.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""
This implementation is based on
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
published under an Apache License 2.0.

COMMENT FROM ORIGINAL:
Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
Copyright Zhun Zhong & Liang Zheng
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import random
import torch


def _get_pixels(
    per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
):
    # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
    # paths, flip the order so normal is run on CPU if this becomes a problem
    # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
    if per_pixel:
        return torch.empty(patch_size, dtype=dtype, device=device).normal_()
    elif rand_color:
        return torch.empty(
            (patch_size[0], 1, 1), dtype=dtype, device=device
        ).normal_()
    else:
        return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)


class RandomErasing:
    """Randomly selects a rectangle region in an image and erases its pixels.
    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf

    This variant of RandomErasing is intended to be applied to either a batch
    or single image tensor after it has been normalized by dataset mean and std.
    Args:
        probability: Probability that the Random Erasing operation will be performed.
        min_area: Minimum fraction of erased area wrt input image area.
        max_area: Maximum fraction of erased area wrt input image area.
        min_aspect: Minimum aspect ratio of erased area.
        mode: pixel color mode, one of 'const', 'rand', or 'pixel'
            'const' - erase block is constant color of 0 for all channels
            'rand'  - erase block is same per-channel random (normal) color
            'pixel' - erase block is per-pixel random (normal) color
        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
            Per-image count is randomly chosen between min_count and this value.
        max_aspect: Maximum aspect ratio of erased area (default: 1 / min_aspect).
        min_count: minimum number of erasing blocks per image.
        num_splits: if > 1, the first ``batch_size // num_splits`` slices of a
            4-D input are left unerased.
        device: device on which the fill values are created.
        cube: if True, a 4-D input is erased with the same box in every slice
            along its first dimension (e.g. every frame of a clip); if False,
            each slice gets independently sampled boxes.
    """

    def __init__(
        self,
        probability=0.5,
        min_area=0.02,
        max_area=1 / 3,
        min_aspect=0.3,
        max_aspect=None,
        mode="const",
        min_count=1,
        max_count=None,
        num_splits=0,
        device="cuda",
        cube=True,
    ):
        self.probability = probability
        self.min_area = min_area
        self.max_area = max_area
        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
        self.min_count = min_count
        self.max_count = max_count or min_count
        self.num_splits = num_splits
        mode = mode.lower()
        self.rand_color = False
        self.per_pixel = False
        self.cube = cube
        if mode == "rand":
            self.rand_color = True  # per block random normal
        elif mode == "pixel":
            self.per_pixel = True  # per pixel random normal
        else:
            assert not mode or mode == "const"
        self.device = device

    def _erase(self, img, chan, img_h, img_w, dtype):
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = (
            self.min_count
            if self.min_count == self.max_count
            else random.randint(self.min_count, self.max_count)
        )
        for _ in range(count):
            for _ in range(10):
                target_area = (
                    random.uniform(self.min_area, self.max_area) * area / count
                )
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    img[:, top : top + h, left : left + w] = _get_pixels(
                        self.per_pixel,
                        self.rand_color,
                        (chan, h, w),
                        dtype=dtype,
                        device=self.device,
                    )
                    break

    def _erase_cube(
        self,
        img,
        batch_start,
        batch_size,
        chan,
        img_h,
        img_w,
        dtype,
    ):
        if random.random() > self.probability:
            return
        area = img_h * img_w
        count = (
            self.min_count
            if self.min_count == self.max_count
            else random.randint(self.min_count, self.max_count)
        )
        for _ in range(count):
            for _ in range(100):
                target_area = (
                    random.uniform(self.min_area, self.max_area) * area / count
                )
                aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
                h = int(round(math.sqrt(target_area * aspect_ratio)))
                w = int(round(math.sqrt(target_area / aspect_ratio)))
                if w < img_w and h < img_h:
                    top = random.randint(0, img_h - h)
                    left = random.randint(0, img_w - w)
                    for i in range(batch_start, batch_size):
                        img_instance = img[i]
                        img_instance[
                            :, top : top + h, left : left + w
                        ] = _get_pixels(
                            self.per_pixel,
                            self.rand_color,
                            (chan, h, w),
                            dtype=dtype,
                            device=self.device,
                        )
                    break

    def __call__(self, input):
        if len(input.size()) == 3:
            self._erase(input, *input.size(), input.dtype)
        else:
            batch_size, chan, img_h, img_w = input.size()
            # skip first slice of batch if num_splits is set (for clean portion of samples)
            batch_start = (
                batch_size // self.num_splits if self.num_splits > 1 else 0
            )
            if self.cube:
                self._erase_cube(
                    input,
                    batch_start,
                    batch_size,
                    chan,
                    img_h,
                    img_w,
                    input.dtype,
                )
            else:
                for i in range(batch_start, batch_size):
                    self._erase(input[i], chan, img_h, img_w, input.dtype)
        return input


================================================
FILE: video_dataset/transform.py
================================================
#!/usr/bin/env python3
# Originates from: https://github.com/facebookresearch/SlowFast/blob/fee19d699c49a81f33b890c5ff592bbb11aa5c54/slowfast/datasets/transform.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
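# Illustrative sketch, not part of the original SlowFast/EVL code: the crop-offset
# arithmetic used by `uniform_crop` below for 3-crop testing, restated in plain
# Python so it can be checked without tensors. The helper name
# `_uniform_crop_offsets_sketch` is hypothetical and nothing in this module calls it.
import math


def _uniform_crop_offsets_sketch(height, width, size, spatial_idx):
    """Return (y_offset, x_offset) of a size x size crop for spatial_idx 0, 1 or 2."""
    assert spatial_idx in (0, 1, 2)
    # Default is a centered crop; ceil mirrors the rounding used in `uniform_crop`.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))
    # Along the longer side, index 0 moves the crop to the start and 2 to the end.
    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    return y_offset, x_offset


# Example: for a 224 x 398 frame and 224-pixel crops, the left, center and right
# crops start at x = 0, 87 and 174 (y stays 0).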
import logging
import math
import numpy as np

# import cv2
import random
import torch
import torchvision as tv
import torchvision.transforms.functional as F
from PIL import Image, ImageFilter
from torchvision import transforms

from .rand_augment import rand_augment_transform
from .random_erasing import RandomErasing


_pil_interpolation_to_str = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}


_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _pil_interp(method):
    if method == "bicubic":
        return Image.BICUBIC
    elif method == "lanczos":
        return Image.LANCZOS
    elif method == "hamming":
        return Image.HAMMING
    else:
        return Image.BILINEAR


logger = logging.getLogger(__name__)


def random_short_side_scale_jitter(
    images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
):
    """
    Perform a spatial short-side scale jittering on the given images and
    corresponding boxes.
    Args:
        images (tensor): images to perform scale jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        min_size (int): the minimal size to scale the frames.
        max_size (int): the maximal size to scale the frames.
        boxes (ndarray): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_size, 1 / min_size] and take a reciprocal to get the
            size. If False, take a uniform sample from [min_size, max_size].
    Returns:
        (tensor): the scaled images with dimension of
            `num frames` x `channel` x `new height` x `new width`.
        (ndarray or None): the scaled boxes with dimension of
            `num boxes` x 4.
    """
    if inverse_uniform_sampling:
        size = int(
            round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
        )
    else:
        size = int(round(np.random.uniform(min_size, max_size)))

    height = images.shape[2]
    width = images.shape[3]
    if (width <= height and width == size) or (
        height <= width and height == size
    ):
        return images, boxes
    new_width = size
    new_height = size
    if width < height:
        new_height = int(math.floor((float(height) / width) * size))
        if boxes is not None:
            boxes = boxes * float(new_height) / height
    else:
        new_width = int(math.floor((float(width) / height) * size))
        if boxes is not None:
            boxes = boxes * float(new_width) / width

    return (
        torch.nn.functional.interpolate(
            images,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        ),
        boxes,
    )


def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes


def random_crop(images, size, boxes=None):
    """
    Perform random spatial crop on the given images and corresponding boxes.
    Args:
        images (tensor): images to perform random crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): the size of height and width to crop on the image.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        cropped (tensor): cropped images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    if images.shape[2] == size and images.shape[3] == size:
        return images, boxes
    height = images.shape[2]
    width = images.shape[3]
    y_offset = 0
    if height > size:
        y_offset = int(np.random.randint(0, height - size))
    x_offset = 0
    if width > size:
        x_offset = int(np.random.randint(0, width - size))
    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]

    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )

    return cropped, cropped_boxes


def horizontal_flip(prob, images, boxes=None):
    """
    Perform horizontal flip on the given images and corresponding boxes.
    Args:
        prob (float): probability to flip the images.
        images (tensor): images to perform horizontal flip, the dimension is
            `num frames` x `channel` x `height` x `width`.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        images (tensor): images with dimension of
            `num frames` x `channel` x `height` x `width`.
        flipped_boxes (ndarray or None): the flipped boxes with dimension of
            `num boxes` x 4.
    """
    if boxes is None:
        flipped_boxes = None
    else:
        flipped_boxes = boxes.copy()

    if np.random.uniform() < prob:
        images = images.flip((-1))

        if len(images.shape) == 3:
            width = images.shape[2]
        elif len(images.shape) == 4:
            width = images.shape[3]
        else:
            raise NotImplementedError("Dimension not supported")
        if boxes is not None:
            flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1

    return images, flipped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images to scale_size
            before performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]
    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes


def clip_boxes_to_image(boxes, height, width):
    """
    Clip an array of boxes to an image with the given height and width.
    Args:
        boxes (ndarray): bounding boxes to perform clipping.
            Dimension is `num boxes` x 4.
        height (int): given image height.
        width (int): given image width.
    Returns:
        clipped_boxes (ndarray): the clipped boxes with dimension of
            `num boxes` x 4.
    """
    clipped_boxes = boxes.copy()
    clipped_boxes[:, [0, 2]] = np.minimum(
        width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
    )
    clipped_boxes[:, [1, 3]] = np.minimum(
        height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
    )
    return clipped_boxes


def blend(images1, images2, alpha):
    """
    Blend two images with a given weight alpha.
    Args:
        images1 (tensor): the first images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        images2 (tensor): the second images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        alpha (float): the blending weight.
    Returns:
        (tensor): blended images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    return images1 * alpha + images2 * (1 - alpha)


def grayscale(images):
    """
    Get the grayscale for the input images. The channels of images should be
    in order BGR.
    Args:
        images (tensor): the input images for getting grayscale. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        img_gray (tensor): grayscale images, with the gray value copied to all
            three channels, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    # R -> 0.299, G -> 0.587, B -> 0.114.
    img_gray = images.clone().detach()
    gray_channel = (
        0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
    )
    img_gray[:, 0] = gray_channel
    img_gray[:, 1] = gray_channel
    img_gray[:, 2] = gray_channel
    return img_gray


def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
    """
    Perform a color jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        img_brightness (float): jitter ratio for brightness.
        img_contrast (float): jitter ratio for contrast.
        img_saturation (float): jitter ratio for saturation.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    jitter = []
    if img_brightness != 0:
        jitter.append("brightness")
    if img_contrast != 0:
        jitter.append("contrast")
    if img_saturation != 0:
        jitter.append("saturation")

    if len(jitter) > 0:
        order = np.random.permutation(np.arange(len(jitter)))
        for idx in range(0, len(jitter)):
            if jitter[order[idx]] == "brightness":
                images = brightness_jitter(img_brightness, images)
            elif jitter[order[idx]] == "contrast":
                images = contrast_jitter(img_contrast, images)
            elif jitter[order[idx]] == "saturation":
                images = saturation_jitter(img_saturation, images)
    return images


def brightness_jitter(var, images):
    """
    Perform brightness jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for brightness.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)

    # zeros_like keeps the dtype and device of the input
    img_bright = torch.zeros_like(images)
    images = blend(images, img_bright, alpha)
    return images


def contrast_jitter(var, images):
    """
    Perform contrast jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for contrast.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)

    img_gray = grayscale(images)
    img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
    images = blend(images, img_gray, alpha)
    return images


def saturation_jitter(var, images):
    """
    Perform saturation jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for saturation.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)
    img_gray = grayscale(images)
    images = blend(images, img_gray, alpha)

    return images


def lighting_jitter(images, alphastd, eigval, eigvec):
    """
    Perform AlexNet-style PCA jitter on the given images.
    Args:
        images (tensor): images to perform lighting jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        alphastd (float): jitter ratio for PCA jitter.
        eigval (list): eigenvalues for PCA jitter.
        eigvec (list[list]): eigenvectors for PCA jitter.
    Returns:
        out_images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    if alphastd == 0:
        return images
    # generate alpha1, alpha2, alpha3.
    alpha = np.random.normal(0, alphastd, size=(1, 3))
    eig_vec = np.array(eigvec)
    eig_val = np.reshape(eigval, (1, 3))
    rgb = np.sum(
        eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
        axis=1,
    )
    out_images = torch.zeros_like(images)
    if len(images.shape) == 3:
        # C H W
        channel_dim = 0
    elif len(images.shape) == 4:
        # T C H W
        channel_dim = 1
    else:
        raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")

    for idx in range(images.shape[channel_dim]):
        # C H W
        if len(images.shape) == 3:
            out_images[idx] = images[idx] + rgb[2 - idx]
        # T C H W
        elif len(images.shape) == 4:
            out_images[:, idx] = images[:, idx] + rgb[2 - idx]
        else:
            raise NotImplementedError(
                f"Unsupported dimension {len(images.shape)}"
            )

    return out_images


def color_normalization(images, mean, stddev):
    """
    Perform color normalization on the given images.
    Args:
        images (tensor): images to perform color normalization. Dimension is
            `num frames` x `channel` x `height` x `width`.
        mean (list): mean values for normalization.
        stddev (list): standard deviations for normalization.

    Returns:
        out_images (tensor): the normalized images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    if len(images.shape) == 3:
        assert (
            len(mean) == images.shape[0]
        ), "channel mean not computed properly"
        assert (
            len(stddev) == images.shape[0]
        ), "channel stddev not computed properly"
    elif len(images.shape) == 4:
        assert (
            len(mean) == images.shape[1]
        ), "channel mean not computed properly"
        assert (
            len(stddev) == images.shape[1]
        ), "channel stddev not computed properly"
    else:
        raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")

    out_images = torch.zeros_like(images)
    for idx in range(len(mean)):
        # C H W
        if len(images.shape) == 3:
            out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
        elif len(images.shape) == 4:
            out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
        else:
            raise NotImplementedError(
                f"Unsupported dimension {len(images.shape)}"
            )
    return out_images


def _get_param_spatial_crop(
    scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
):
    """
    Given scale, ratio, height and width, return the sampled crop parameters
    (i, j, h, w): the top row, left column, height and width of the crop.
    """
    for _ in range(num_repeat):
        area = height * width
        target_area = random.uniform(*scale) * area
        if log_scale:
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
        else:
            aspect_ratio = random.uniform(*ratio)

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        if np.random.uniform() < 0.5 and switch_hw:
            w, h = h, w

        if 0 < w <= width and 0 < h <= height:
            i = random.randint(0, height - h)
            j = random.randint(0, width - w)
            return i, j, h, w

    # Fallback to central crop
    in_ratio = float(width) / float(height)
    if in_ratio < min(ratio):
        w = width
        h = int(round(w / min(ratio)))
    elif in_ratio > max(ratio):
        h = height
        w = int(round(h * max(ratio)))
    else:  # whole image
        w = width
        h = height
    i = (height - h) // 2
    j = (width - w) // 2
    return i, j, h, w


def random_resized_crop(
    images,
    target_height,
    target_width,
    scale=(0.08, 1.0),
    ratio=(3.0 / 4.0, 4.0 / 3.0),
):
    """
    Crop the given images to
    random size and aspect ratio. A crop of random size (default: of 0.08
    to 1.0) of the original size and a random aspect ratio (default: of 3/4
    to 4/3) of the original aspect ratio is made. This crop is finally resized
    to given size. This is popularly used to train the Inception networks.

    Args:
        images: Images to perform resizing and cropping.
        target_height: Desired height after cropping.
        target_width: Desired width after cropping.
        scale: Scale range of Inception-style area based random resizing.
        ratio: Aspect ratio range of Inception-style area based random resizing.
    """

    height = images.shape[2]
    width = images.shape[3]

    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
    cropped = images[:, :, i : i + h, j : j + w]
    return torch.nn.functional.interpolate(
        cropped,
        size=(target_height, target_width),
        mode="bilinear",
        align_corners=False,
    )


def random_resized_crop_with_shift(
    images,
    target_height,
    target_width,
    scale=(0.8, 1.0),
    ratio=(3.0 / 4.0, 4.0 / 3.0),
):
    """
    This is similar to random_resized_crop. However, it samples two different
    boxes (for cropping) for the first and last frame. It then linearly
    interpolates the two boxes for other frames.

    Args:
        images: Images to perform resizing and cropping, in
            `channel` x `num frames` x `height` x `width` layout.
        target_height: Desired height after cropping.
        target_width: Desired width after cropping.
        scale: Scale range of Inception-style area based random resizing.
        ratio: Aspect ratio range of Inception-style area based random resizing.
    """
    t = images.shape[1]
    height = images.shape[2]
    width = images.shape[3]

    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
    i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
    i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
    j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
    h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
    w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
    out = torch.zeros((3, t, target_height, target_width))
    for ind in range(t):
        out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
            images[
                :,
                ind : ind + 1,
                i_s[ind] : i_s[ind] + h_s[ind],
                j_s[ind] : j_s[ind] + w_s[ind],
            ],
            size=(target_height, target_width),
            mode="bilinear",
            align_corners=False,
        )
    return out


def create_random_augment(
    input_size,
    auto_augment=None,
    interpolation="bilinear",
):
    """
    Get video randaug transform.

    Args:
        input_size: The size of the input video in tuple.
        auto_augment: Parameters for randaug. An example:
            "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
            of operations to apply).
        interpolation: Interpolation method.
    """
    if isinstance(input_size, tuple):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = {"translate_const": int(img_size_min * 0.45)}
        if interpolation and interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        if auto_augment.startswith("rand"):
            return transforms.Compose(
                [rand_augment_transform(auto_augment, aa_params)]
            )
    raise NotImplementedError


def random_sized_crop_img(
    im,
    size,
    jitter_scale=(0.08, 1.0),
    jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
    max_iter=10,
):
    """
    Performs Inception-style cropping (used for training).
    """
    assert (
        len(im.shape) == 3
    ), "Currently only support image for random_sized_crop"
    h, w = im.shape[1:3]
    i, j, h, w = _get_param_spatial_crop(
        scale=jitter_scale,
        ratio=jitter_aspect,
        height=h,
        width=w,
        num_repeat=max_iter,
        log_scale=False,
        switch_hw=True,
    )
    cropped = im[:, i : i + h, j : j + w]
    return torch.nn.functional.interpolate(
        cropped.unsqueeze(0),
        size=(size, size),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)


# The following code is modified from the timm lib; we will replace the following
# contents with a dependency on PyTorchVideo.
# https://github.com/facebookresearch/pytorchvideo
class RandomResizedCropAndInterpolation:
    """Crop the given PIL Image to random size and aspect ratio with random interpolation.
    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.
    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation="bilinear",
    ):
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            logger.warning("range should be of kind (min, max)")

        if interpolation == "random":
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = _pil_interp(interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.
        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        area = img.size[0] * img.size[1]

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = img.size[0] / img.size[1]
        if in_ratio < min(ratio):
            w = img.size[0]
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = img.size[1]
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img.size[0]
            h = img.size[1]
        i = (img.size[1] - h) // 2
        j = (img.size[0] - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.
        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        return F.resized_crop(img, i, j, h, w, self.size, interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = " ".join(
                [_pil_interpolation_to_str[x] for x in self.interpolation]
            )
        else:
            interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + "(size={0}".format(self.size)
        format_string += ", scale={0}".format(
            tuple(round(s, 4) for s in self.scale)
        )
        format_string += ", ratio={0}".format(
            tuple(round(r, 4) for r in self.ratio)
        )
        format_string += ", interpolation={0})".format(interpolate_str)
        return format_string


================================================
FILE: vision_transformer.py
================================================
#!/usr/bin/env python

from collections import OrderedDict

import numpy as np
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

'''
QuickGELU and LayerNorm w/ fp16 from official CLIP repo
(https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py)
'''
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class Attention(nn.Module):
    '''
    A generalized attention module with more flexibility.
    '''

    def __init__(
        self,
        q_in_dim: int,
        k_in_dim: int,
        v_in_dim: int,
        qk_proj_dim: int,
        v_proj_dim: int,
        num_heads: int,
        out_dim: int,
        return_all_features: bool = False,
    ):
        super().__init__()

        self.q_proj = nn.Linear(q_in_dim, qk_proj_dim)
        self.k_proj = nn.Linear(k_in_dim, qk_proj_dim)
        self.v_proj = nn.Linear(v_in_dim, v_proj_dim)
        self.out_proj = nn.Linear(v_proj_dim, out_dim)

        self.num_heads = num_heads
        self.return_all_features = return_all_features
        assert qk_proj_dim % num_heads == 0 and v_proj_dim % num_heads == 0

        self._initialize_weights()

    def _initialize_weights(self):
        for m in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
        N = q.size(0)
        assert k.size(0) == N and v.size(0) == N
        Lq, Lkv = q.size(1), k.size(1)
        assert v.size(1) == Lkv

        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)

        H = self.num_heads
        Cqk, Cv = q.size(-1) // H, v.size(-1) // H

        q = q.view(N, Lq, H, Cqk)
        k = k.view(N, Lkv, H, Cqk)
        v = v.view(N, Lkv, H, Cv)

        aff = torch.einsum('nqhc,nkhc->nqkh', q / (Cqk ** 0.5), k)
        aff = aff.softmax(dim=-2)
        mix = torch.einsum('nqlh,nlhc->nqhc', aff, v)

        out = self.out_proj(mix.flatten(-2))

        if self.return_all_features:
            return dict(q=q, k=k, v=v, aff=aff, out=out)
        else:
            return out


class PatchEmbed2D(nn.Module):

    def __init__(
        self,
        patch_size: Tuple[int, int] = (16, 16),
        in_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels

        self.proj = nn.Linear(np.prod(patch_size) * in_channels, embed_dim)

    def _initialize_weights(self, x):
        nn.init.kaiming_normal_(self.proj.weight, 0.)
        nn.init.constant_(self.proj.bias, 0.)

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.size()
        pH, pW = self.patch_size

        assert C == self.in_channels and H % pH == 0 and W % pW == 0

        x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2)
        x = self.proj(x)

        return x


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        mlp_dropout: float = 0.0,
        act: nn.Module = QuickGELU,
        return_all_features: bool = False,
    ):
        super().__init__()

        self.return_all_features = return_all_features

        self.attn = Attention(
            q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
            qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,
            return_all_features=return_all_features,
        )

        mlp_dim = round(mlp_factor * in_feature_dim)
        self.mlp = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(in_feature_dim, mlp_dim)),
            ('act', act()),
            ('dropout', nn.Dropout(mlp_dropout)),
            ('fc2', nn.Linear(mlp_dim, in_feature_dim)),
        ]))

        self.norm1 = LayerNorm(in_feature_dim)
        self.norm2 = LayerNorm(in_feature_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in (self.mlp[0], self.mlp[-1]):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x: torch.Tensor):
        if self.return_all_features:
            ret_dict = {}

            x_norm = self.norm1(x)
            attn_out = self.attn(x_norm, x_norm, x_norm)
            ret_dict['q'] = attn_out['q']
            ret_dict['k'] = attn_out['k']
            ret_dict['v'] = attn_out['v']
            ret_dict['attn_out'] = attn_out['out']
            x = x + attn_out['out']

            x = x + self.mlp(self.norm2(x))
            ret_dict['out'] = x

            return ret_dict

        else:
            x_norm = self.norm1(x)
            x = x + self.attn(x_norm, x_norm, x_norm)
            x = x + self.mlp(self.norm2(x))

            return x


class TransformerDecoderLayer(nn.Module):

    def __init__(
        self,
        in_feature_dim: int = 768,
        qkv_dim: int = 768,
        num_heads: int = 12,
        mlp_factor: float = 4.0,
        mlp_dropout: float = 0.0,
        act: nn.Module = QuickGELU,
    ):
        super().__init__()

        self.attn = Attention(
            q_in_dim=in_feature_dim, k_in_dim=in_feature_dim, v_in_dim=in_feature_dim,
            qk_proj_dim=qkv_dim, v_proj_dim=qkv_dim, num_heads=num_heads, out_dim=in_feature_dim,
        )

        mlp_dim = round(mlp_factor * in_feature_dim)
        self.mlp = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(in_feature_dim, mlp_dim)),
            ('act', act()),
            ('dropout', nn.Dropout(mlp_dropout)),
            ('fc2', nn.Linear(mlp_dim, in_feature_dim)),
        ]))

        self.norm1 = LayerNorm(in_feature_dim)
        self.norm2 = LayerNorm(in_feature_dim)
        self.norm3 = LayerNorm(in_feature_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in (self.mlp[0], self.mlp[-1]):
            nn.init.xavier_uniform_(m.weight)
            nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        y_norm = self.norm3(y)
        x = x + self.attn(self.norm1(x), y_norm, y_norm)
        x = x + self.mlp(self.norm2(x))

        return x


class VisionTransformer2D(nn.Module):

    def __init__(
        self,
        feature_dim: int = 768,
        input_size: Tuple[int, int] = (224, 224),
        patch_size: Tuple[int, int] = (16, 16),
        num_heads: int = 12,
        num_layers: int = 12,
        mlp_factor: float = 4.0,
        act: nn.Module = QuickGELU,
        return_all_features: bool = False,
        ln_pre: bool = False,
    ):
        super().__init__()

        self.return_all_features = return_all_features

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, embed_dim=feature_dim)
        self.num_patches = np.prod([x // y for x, y in zip(input_size, patch_size)]) + 1

        self.cls_token = nn.Parameter(torch.zeros([feature_dim]))
        self.pos_embed = nn.Parameter(torch.zeros([self.num_patches, feature_dim]))

        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(
                in_feature_dim=feature_dim, qkv_dim=feature_dim, num_heads=num_heads, mlp_factor=mlp_factor, act=act,
                return_all_features=return_all_features,
            ) for _ in range(num_layers)
        ])

        if ln_pre:
            self.ln_pre = LayerNorm(feature_dim)
        else:
            self.ln_pre = nn.Identity()

        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, x: torch.Tensor):
        dtype = self.patch_embed.proj.weight.dtype
        x = x.to(dtype)

        x = self.patch_embed(x)
        x = torch.cat([self.cls_token.view(1, 1, -1).repeat(x.size(0), 1, 1), x], dim=1)
        x = x + self.pos_embed

        x = self.ln_pre(x)

        if self.return_all_features:
            all_features = []
            for blk in self.blocks:
                x = blk(x)
                all_features.append(x)
                x = x['out']
            return all_features

        else:
            for blk in self.blocks:
                x = blk(x)
            return x


def model_to_fp16(model: VisionTransformer2D):
    def _module_to_fp16(m: nn.Module):
        if isinstance(m, (nn.Linear,)):
            m.half()
    model.apply(_module_to_fp16)

    model.pos_embed.data = model.pos_embed.data.half()
    model.cls_token.data = model.cls_token.data.half()


vit_presets = {
    'ViT-B/16-lnpre': dict(
        feature_dim=768,
        input_size=(224, 224),
        patch_size=(16, 16),
        num_heads=12,
        num_layers=12,
        mlp_factor=4.0,
        ln_pre=True,
    ),
    'ViT-L/14-lnpre': dict(
        feature_dim=1024,
        input_size=(224, 224),
        patch_size=(14, 14),
        num_heads=16,
        num_layers=24,
        mlp_factor=4.0,
        ln_pre=True,
    ),
}

================================================
FILE: weight_loaders.py
================================================
#!/usr/bin/env python

import os, sys
from typing import Dict

import torch

__all__ = ['weight_loader_fn_dict']

def load_weights_clip(load_path: str) -> Dict[str, torch.Tensor]:
    clip_model = torch.jit.load(load_path, map_location='cpu')
    clip_model = clip_model.visual
    src_state_dict = clip_model.state_dict()
    src_state_dict = dict((k, v.float()) for k, v in src_state_dict.items())

    dst_state_dict = {}

    dst_state_dict['cls_token'] = src_state_dict['class_embedding']
    dst_state_dict['pos_embed'] = src_state_dict['positional_embedding']
    dst_state_dict['patch_embed.proj.weight'] = src_state_dict['conv1.weight'].flatten(1)
    dst_state_dict['patch_embed.proj.bias'] = torch.zeros([src_state_dict['conv1.weight'].size(0)])

    dst_state_dict['ln_pre.weight'] = src_state_dict['ln_pre.weight']
    dst_state_dict['ln_pre.bias'] = src_state_dict['ln_pre.bias']

    block_idx = 0
    while True:
        src_prefix = 'transformer.resblocks.%d.' % block_idx
        dst_prefix = 'blocks.%d.' % block_idx

        src_block_state_dict = dict((k[len(src_prefix):], v) for k, v in src_state_dict.items() if k.startswith(src_prefix))
        if len(src_block_state_dict) == 0:
            break

        dst_block_state_dict = {}
        feat_dim = src_block_state_dict['ln_1.weight'].size(0)

        for i, dst_name in enumerate(('q', 'k', 'v')):
            dst_block_state_dict['attn.%s_proj.weight' % dst_name] = src_block_state_dict['attn.in_proj_weight'][feat_dim * i: feat_dim * (i + 1)]
            dst_block_state_dict['attn.%s_proj.bias' % dst_name] = src_block_state_dict['attn.in_proj_bias'][feat_dim * i: feat_dim * (i + 1)]

        dst_block_state_dict['attn.out_proj.weight'] = src_block_state_dict['attn.out_proj.weight']
        dst_block_state_dict['attn.out_proj.bias'] = src_block_state_dict['attn.out_proj.bias']

        dst_block_state_dict['mlp.fc1.weight'] = src_block_state_dict['mlp.c_fc.weight']
        dst_block_state_dict['mlp.fc1.bias'] = src_block_state_dict['mlp.c_fc.bias']
        dst_block_state_dict['mlp.fc2.weight'] = src_block_state_dict['mlp.c_proj.weight']
        dst_block_state_dict['mlp.fc2.bias'] = src_block_state_dict['mlp.c_proj.bias']

        dst_block_state_dict['norm1.weight'] = src_block_state_dict['ln_1.weight']
        dst_block_state_dict['norm1.bias'] = src_block_state_dict['ln_1.bias']
        dst_block_state_dict['norm2.weight'] = src_block_state_dict['ln_2.weight']
        dst_block_state_dict['norm2.bias'] = src_block_state_dict['ln_2.bias']

        dst_state_dict.update(dict((dst_prefix + k, v) for k, v in dst_block_state_dict.items()))
        block_idx += 1

    return dst_state_dict


weight_loader_fn_dict = {
    'clip': load_weights_clip,
}
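

# ---------------------------------------------------------------------------
# Illustrative sanity check: a minimal sketch, not part of the loader above.
# load_weights_clip turns CLIP's patchify convolution weight `conv1.weight`
# (shape [D, C, pH, pW]) into PatchEmbed2D's linear weight via `.flatten(1)`
# and sets the linear bias to zero. That mapping is only equivalent if
# PatchEmbed2D flattens each patch in (C, pH, pW) order, which is what its
# `permute(0, 2, 4, 1, 3, 5).flatten(3).flatten(1, 2)` produces. The helper
# below re-creates both computations in NumPy on random data and returns the
# largest absolute difference. The helper name `_check_patch_flatten_order`
# and the toy sizes are hypothetical choices made for illustration only.
# ---------------------------------------------------------------------------
def _check_patch_flatten_order(seed: int = 0) -> float:
    import numpy as np

    rng = np.random.default_rng(seed)
    B, C, H, W, pH, pW, D = 2, 3, 8, 12, 4, 4, 5
    nH, nW = H // pH, W // pW
    x = rng.standard_normal((B, C, H, W))
    conv_w = rng.standard_normal((D, C, pH, pW))  # same layout as CLIP's conv1.weight

    # Mirror PatchEmbed2D.forward: split into patches, move (C, pH, pW) last, flatten.
    patches = x.reshape(B, C, nH, pH, nW, pW).transpose(0, 2, 4, 1, 3, 5)
    patches = patches.reshape(B, nH * nW, C * pH * pW)
    # Linear layer whose weight is conv1.weight.flatten(1) and whose bias is zero.
    linear_out = patches @ conv_w.reshape(D, -1).T

    # Reference: convolution with stride == kernel size and no bias,
    # with output positions enumerated row-major as (row, col).
    conv_out = np.empty_like(linear_out)
    for b in range(B):
        for r in range(nH):
            for c in range(nW):
                window = x[b, :, r * pH:(r + 1) * pH, c * pW:(c + 1) * pW]
                conv_out[b, r * nW + c] = np.tensordot(conv_w, window, axes=([1, 2, 3], [0, 1, 2]))

    return float(np.abs(linear_out - conv_out).max())


if __name__ == '__main__':
    # Expect a difference at floating-point round-off level.
    print('max abs diff:', _check_patch_flatten_order())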