Repository: wjchaoGit/Group-Activity-Recognition Branch: master Commit: c8c1dff953d6 Files: 19 Total size: 23.9 MB Directory structure: gitextract_8dab1m4n/ ├── README.md ├── backbone.py ├── base_model.py ├── collective.py ├── config.py ├── data/ │ ├── collective/ │ │ └── tracks/ │ │ ├── readTracks.m │ │ └── showTracks.m │ └── volleyball/ │ ├── src_image_size.pkl │ └── tracks_normalized.pkl ├── dataset.py ├── gcn_model.py ├── result/ │ └── .gitkeep ├── scripts/ │ ├── train_collective_stage1.py │ ├── train_collective_stage2.py │ ├── train_volleyball_stage1.py │ └── train_volleyball_stage2.py ├── train_net.py ├── utils.py └── volleyball.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # Learning Actor Relation Graphs for Group Activity Recognition Source code for the following paper ([arXiv link](https://arxiv.org/abs/1904.10117)): Learning Actor Relation Graphs for Group Activity Recognition Jianchao Wu, Limin Wang, Li Wang, Jie Guo, Gangshan Wu in CVPR 2019 ## Dependencies - Python `3.x` - PyTorch `0.4.1` - numpy, pickle, scikit-image - [RoIAlign for Pytorch](https://github.com/longcw/RoIAlign.pytorch) - Datasets: [Volleyball](https://github.com/mostafa-saad/deep-activity-rec), [Collective](http://vhosts.eecs.umich.edu/vision//activity-dataset.html) ## Prepare Datasets 1. Download the [volleyball](http://vml.cs.sfu.ca/wp-content/uploads/volleyballdataset/volleyball.zip) or [collective](http://vhosts.eecs.umich.edu/vision//ActivityDataset.zip) dataset file. 2. Unzip the dataset file into `data/volleyball` or `data/collective`. ## Get Started 1. Stage 1: Fine-tune the model on single frames, without the GCN. ```shell # volleyball dataset python scripts/train_volleyball_stage1.py # collective dataset python scripts/train_collective_stage1.py ``` 2. Stage 2: Freeze the weights of the feature-extraction part of the network, and train the full network with the GCN. ```shell # volleyball dataset python scripts/train_volleyball_stage2.py # collective dataset python scripts/train_collective_stage2.py ``` You can specify the running arguments in the Python files under the `scripts/` directory.
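For example, a stage-2 volleyball run is configured entirely by overriding `Config` fields before calling `train_net` (a condensed sketch of `scripts/train_volleyball_stage2.py`; `result/STAGE1_MODEL.pth` is a placeholder you must point at your own stage-1 checkpoint):

```python
import sys
sys.path.append(".")
from train_net import *

cfg = Config('volleyball')
cfg.training_stage = 2
cfg.stage1_model_path = 'result/STAGE1_MODEL.pth'  # placeholder path to the stage-1 model
cfg.train_backbone = False  # backbone stays frozen in stage 2
cfg.batch_size = 32
cfg.num_frames = 3
cfg.exp_note = 'Volleyball_stage2'
train_net(cfg)
```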
The meanings of arguments can be found in `config.py` ## Citation ``` @inproceedings{CVPR2019_ARG, title = {Learning Actor Relation Graphs for Group Activity Recognition}, author = {Jianchao Wu and Limin Wang and Li Wang and Jie Guo and Gangshan Wu}, booktitle = {CVPR}, year = {2019}, } ``` ================================================ FILE: backbone.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models class MyInception_v3(nn.Module): def __init__(self,transform_input=False,pretrained=False): super(MyInception_v3,self).__init__() self.transform_input=transform_input inception=models.inception_v3(pretrained=pretrained) self.Conv2d_1a_3x3 = inception.Conv2d_1a_3x3 self.Conv2d_2a_3x3 = inception.Conv2d_2a_3x3 self.Conv2d_2b_3x3 = inception.Conv2d_2b_3x3 self.Conv2d_3b_1x1 = inception.Conv2d_3b_1x1 self.Conv2d_4a_3x3 = inception.Conv2d_4a_3x3 self.Mixed_5b = inception.Mixed_5b self.Mixed_5c = inception.Mixed_5c self.Mixed_5d = inception.Mixed_5d self.Mixed_6a = inception.Mixed_6a self.Mixed_6b = inception.Mixed_6b self.Mixed_6c = inception.Mixed_6c self.Mixed_6d = inception.Mixed_6d self.Mixed_6e = inception.Mixed_6e def forward(self,x): outputs=[] if self.transform_input: x = x.clone() x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 # 299 x 299 x 3 x = self.Conv2d_1a_3x3(x) # 149 x 149 x 32 x = self.Conv2d_2a_3x3(x) # 147 x 147 x 32 x = self.Conv2d_2b_3x3(x) # 147 x 147 x 64 x = F.max_pool2d(x, kernel_size=3, stride=2) # 73 x 73 x 64 x = self.Conv2d_3b_1x1(x) # 73 x 73 x 80 x = self.Conv2d_4a_3x3(x) # 71 x 71 x 192 x = F.max_pool2d(x, kernel_size=3, stride=2) # 35 x 35 x 192 x = self.Mixed_5b(x) # 35 x 35 x 256 x = self.Mixed_5c(x) # 35 x 35 x 288 x = self.Mixed_5d(x) # 35 x 35 x 288 outputs.append(x) x = self.Mixed_6a(x) # 17 x 17 x 768 x = self.Mixed_6b(x) # 17 x 17 x 768 x = self.Mixed_6c(x) # 17 x 17 x 768 x = self.Mixed_6d(x) # 17 x 17 x 768 x = self.Mixed_6e(x) # 17 x 17 x 768 outputs.append(x) return outputs class MyVGG16(nn.Module): def __init__(self,pretrained=False): super(MyVGG16,self).__init__() vgg=models.vgg16(pretrained=pretrained) self.features=vgg.features def forward(self,x): x=self.features(x) return [x] class MyVGG19(nn.Module): def __init__(self,pretrained=False): super(MyVGG19,self).__init__() vgg=models.vgg19(pretrained=pretrained) self.features=vgg.features def forward(self,x): x=self.features(x) return [x] ================================================ FILE: base_model.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from backbone import * from utils import * from roi_align.roi_align import RoIAlign # RoIAlign module from roi_align.roi_align import CropAndResize # crop_and_resize module class Basenet_volleyball(nn.Module): """ main module of base model for the volleyball """ def __init__(self, cfg): super(Basenet_volleyball, self).__init__() self.cfg=cfg NFB=self.cfg.num_features_boxes D=self.cfg.emb_features K=self.cfg.crop_size[0] if cfg.backbone=='inv3': self.backbone=MyInception_v3(transform_input=False,pretrained=True) elif cfg.backbone=='vgg16': self.backbone=MyVGG16(pretrained=True) elif cfg.backbone=='vgg19': self.backbone=MyVGG19(pretrained=True) else: assert False self.roi_align=RoIAlign(*self.cfg.crop_size) self.fc_emb = nn.Linear(K*K*D,NFB) 
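# fc_emb maps each flattened K*K*D RoI crop to an NFB-dimensional box feature; the dropout below regularizes it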
self.dropout_emb = nn.Dropout(p=self.cfg.train_dropout_prob) self.fc_actions=nn.Linear(NFB,self.cfg.num_actions) self.fc_activities=nn.Linear(NFB,self.cfg.num_activities) for m in self.modules(): if isinstance(m,nn.Linear): nn.init.kaiming_normal_(m.weight) nn.init.zeros_(m.bias) def savemodel(self,filepath): state = { 'backbone_state_dict': self.backbone.state_dict(), 'fc_emb_state_dict':self.fc_emb.state_dict(), 'fc_actions_state_dict':self.fc_actions.state_dict(), 'fc_activities_state_dict':self.fc_activities.state_dict() } torch.save(state, filepath) print('model saved to:',filepath) def loadmodel(self,filepath): state = torch.load(filepath) self.backbone.load_state_dict(state['backbone_state_dict']) self.fc_emb.load_state_dict(state['fc_emb_state_dict']) self.fc_actions.load_state_dict(state['fc_actions_state_dict']) self.fc_activities.load_state_dict(state['fc_activities_state_dict']) print('Load model states from: ',filepath) def forward(self,batch_data): images_in, boxes_in = batch_data # read config parameters B=images_in.shape[0] T=images_in.shape[1] H, W=self.cfg.image_size OH, OW=self.cfg.out_size N=self.cfg.num_boxes NFB=self.cfg.num_features_boxes # Reshape the input data images_in_flat=torch.reshape(images_in,(B*T,3,H,W)) #B*T, 3, H, W boxes_in_flat=torch.reshape(boxes_in,(B*T*N,4)) #B*T*N, 4 boxes_idx=[i * torch.ones(N, dtype=torch.int) for i in range(B*T) ] boxes_idx=torch.stack(boxes_idx).to(device=boxes_in.device) # B*T, N boxes_idx_flat=torch.reshape(boxes_idx,(B*T*N,)) #B*T*N, # Use backbone to extract features of images_in # Pre-process first images_in_flat=prep_images(images_in_flat) outputs=self.backbone(images_in_flat) # Build multiscale features features_multiscale=[] for features in outputs: if features.shape[2:4]!=torch.Size([OH,OW]): features=F.interpolate(features,size=(OH,OW),mode='bilinear',align_corners=True) features_multiscale.append(features) features_multiscale=torch.cat(features_multiscale,dim=1) #B*T, D, OH, OW # ActNet boxes_in_flat.requires_grad=False boxes_idx_flat.requires_grad=False # features_multiscale.requires_grad=False # RoI Align boxes_features=self.roi_align(features_multiscale, boxes_in_flat, boxes_idx_flat) #B*T*N, D, K, K, boxes_features=boxes_features.reshape(B*T*N,-1) # B*T*N, D*K*K # Embedding to hidden state boxes_features=self.fc_emb(boxes_features) # B*T*N, NFB boxes_features=F.relu(boxes_features) boxes_features=self.dropout_emb(boxes_features) boxes_states=boxes_features.reshape(B,T,N,NFB) # Predict actions boxes_states_flat=boxes_states.reshape(-1,NFB) #B*T*N, NFB actions_scores=self.fc_actions(boxes_states_flat) #B*T*N, actn_num # Predict activities boxes_states_pooled,_=torch.max(boxes_states,dim=2) #B, T, NFB boxes_states_pooled_flat=boxes_states_pooled.reshape(-1,NFB) #B*T, NFB activities_scores=self.fc_activities(boxes_states_pooled_flat) #B*T, acty_num if T!=1: actions_scores=actions_scores.reshape(B,T,N,-1).mean(dim=1).reshape(B*N,-1) activities_scores=activities_scores.reshape(B,T,-1).mean(dim=1) return actions_scores, activities_scores class Basenet_collective(nn.Module): """ main module of base model for collective dataset """ def __init__(self, cfg): super(Basenet_collective, self).__init__() self.cfg=cfg D=self.cfg.emb_features K=self.cfg.crop_size[0] NFB=self.cfg.num_features_boxes NFR, NFG=self.cfg.num_features_relation, self.cfg.num_features_gcn self.backbone=MyInception_v3(transform_input=False,pretrained=True) # self.backbone=MyVGG16(pretrained=True) if not self.cfg.train_backbone: for p in
self.backbone.parameters(): p.requires_grad=False self.roi_align=RoIAlign(*self.cfg.crop_size) self.fc_emb_1=nn.Linear(K*K*D,NFB) self.dropout_emb_1 = nn.Dropout(p=self.cfg.train_dropout_prob) # self.nl_emb_1=nn.LayerNorm([NFB]) self.fc_actions=nn.Linear(NFB,self.cfg.num_actions) self.fc_activities=nn.Linear(NFB,self.cfg.num_activities) for m in self.modules(): if isinstance(m,nn.Linear): nn.init.kaiming_normal_(m.weight) def savemodel(self,filepath): state = { 'backbone_state_dict': self.backbone.state_dict(), 'fc_emb_state_dict':self.fc_emb_1.state_dict(), 'fc_actions_state_dict':self.fc_actions.state_dict(), 'fc_activities_state_dict':self.fc_activities.state_dict() } torch.save(state, filepath) print('model saved to:',filepath) def loadmodel(self,filepath): state = torch.load(filepath) self.backbone.load_state_dict(state['backbone_state_dict']) self.fc_emb_1.load_state_dict(state['fc_emb_state_dict']) print('Load model states from: ',filepath) def forward(self,batch_data): images_in, boxes_in, bboxes_num_in = batch_data # read config parameters B=images_in.shape[0] T=images_in.shape[1] H, W=self.cfg.image_size OH, OW=self.cfg.out_size MAX_N=self.cfg.num_boxes NFB=self.cfg.num_features_boxes NFR, NFG=self.cfg.num_features_relation, self.cfg.num_features_gcn EPS=1e-5 D=self.cfg.emb_features K=self.cfg.crop_size[0] # Reshape the input data images_in_flat=torch.reshape(images_in,(B*T,3,H,W)) #B*T, 3, H, W boxes_in=boxes_in.reshape(B*T,MAX_N,4) # Use backbone to extract features of images_in # Pre-process first images_in_flat=prep_images(images_in_flat) outputs=self.backbone(images_in_flat) # Build multiscale features features_multiscale=[] for features in outputs: if features.shape[2:4]!=torch.Size([OH,OW]): features=F.interpolate(features,size=(OH,OW),mode='bilinear',align_corners=True) features_multiscale.append(features) features_multiscale=torch.cat(features_multiscale,dim=1) #B*T, D, OH, OW boxes_in_flat=torch.reshape(boxes_in,(B*T*MAX_N,4)) #B*T*MAX_N, 4 boxes_idx=[i * torch.ones(MAX_N, dtype=torch.int) for i in range(B*T) ] boxes_idx=torch.stack(boxes_idx).to(device=boxes_in.device) # B*T, MAX_N boxes_idx_flat=torch.reshape(boxes_idx,(B*T*MAX_N,)) #B*T*MAX_N, # RoI Align boxes_in_flat.requires_grad=False boxes_idx_flat.requires_grad=False boxes_features_all=self.roi_align(features_multiscale, boxes_in_flat, boxes_idx_flat) #B*T*MAX_N, D, K, K, boxes_features_all=boxes_features_all.reshape(B*T,MAX_N,-1) #B*T,MAX_N, D*K*K # Embedding boxes_features_all=self.fc_emb_1(boxes_features_all) # B*T,MAX_N, NFB boxes_features_all=F.relu(boxes_features_all) boxes_features_all=self.dropout_emb_1(boxes_features_all) actions_scores=[] activities_scores=[] bboxes_num_in=bboxes_num_in.reshape(B*T,) #B*T, for bt in range(B*T): N=bboxes_num_in[bt] boxes_features=boxes_features_all[bt,:N,:].reshape(1,N,NFB) #1,N,NFB boxes_states=boxes_features NFS=NFB # Predict actions boxes_states_flat=boxes_states.reshape(-1,NFS) #1*N, NFS actn_score=self.fc_actions(boxes_states_flat) #1*N, actn_num actions_scores.append(actn_score) # Predict activities boxes_states_pooled,_=torch.max(boxes_states,dim=1) #1, NFS boxes_states_pooled_flat=boxes_states_pooled.reshape(-1,NFS) #1, NFS acty_score=self.fc_activities(boxes_states_pooled_flat) #1, acty_num activities_scores.append(acty_score) actions_scores=torch.cat(actions_scores,dim=0) #ALL_N,actn_num activities_scores=torch.cat(activities_scores,dim=0) #B*T,acty_num # print(actions_scores.shape) # print(activities_scores.shape) return actions_scores, activities_scores
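Before moving on, the tensor contract of the base model can be checked with random inputs. This is an illustrative smoke test, not part of the repo; it assumes the RoIAlign extension from the README is installed and that the ImageNet-pretrained Inception-v3 weights can be downloaded:

```python
import torch
from config import Config
from base_model import Basenet_volleyball

cfg = Config('volleyball')
model = Basenet_volleyball(cfg)
model.eval()

B, T, N = 2, 1, cfg.num_boxes              # stage 1 uses single frames (T=1)
H, W = cfg.image_size
images = torch.rand(B, T, 3, H, W) * 255   # raw pixels; prep_images() rescales inside forward
xy = torch.rand(B, T, N, 2) * 100          # top-left corners in feature-map pixels
wh = torch.rand(B, T, N, 2) * 40 + 1       # positive widths/heights
boxes = torch.cat([xy, xy + wh], dim=3)    # (x1, y1, x2, y2), like the dataset emits

with torch.no_grad():
    actions, activities = model((images, boxes))
print(actions.shape)     # (B*T*N, num_actions)  -> torch.Size([24, 9])
print(activities.shape)  # (B*T, num_activities) -> torch.Size([2, 8])
```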
================================================ FILE: collective.py ================================================ import torch from torch.utils import data import torchvision.models as models import torchvision.transforms as transforms import random from PIL import Image import numpy as np from collections import Counter FRAMES_NUM={1: 302, 2: 347, 3: 194, 4: 257, 5: 536, 6: 401, 7: 968, 8: 221, 9: 356, 10: 302, 11: 1813, 12: 1084, 13: 851, 14: 723, 15: 464, 16: 1021, 17: 905, 18: 600, 19: 203, 20: 342, 21: 650, 22: 361, 23: 311, 24: 321, 25: 617, 26: 734, 27: 1804, 28: 470, 29: 635, 30: 356, 31: 690, 32: 194, 33: 193, 34: 395, 35: 707, 36: 914, 37: 1049, 38: 653, 39: 518, 40: 401, 41: 707, 42: 420, 43: 410, 44: 356} FRAMES_SIZE={1: (480, 720), 2: (480, 720), 3: (480, 720), 4: (480, 720), 5: (480, 720), 6: (480, 720), 7: (480, 720), 8: (480, 720), 9: (480, 720), 10: (480, 720), 11: (480, 720), 12: (480, 720), 13: (480, 720), 14: (480, 720), 15: (450, 800), 16: (480, 720), 17: (480, 720), 18: (480, 720), 19: (480, 720), 20: (450, 800), 21: (450, 800), 22: (450, 800), 23: (450, 800), 24: (450, 800), 25: (480, 720), 26: (480, 720), 27: (480, 720), 28: (480, 720), 29: (480, 720), 30: (480, 720), 31: (480, 720), 32: (480, 720), 33: (480, 720), 34: (480, 720), 35: (480, 720), 36: (480, 720), 37: (480, 720), 38: (480, 720), 39: (480, 720), 40: (480, 720), 41: (480, 720), 42: (480, 720), 43: (480, 720), 44: (480, 720)} ACTIONS=['NA','Crossing','Waiting','Queueing','Walking','Talking'] ACTIVITIES=['Crossing','Waiting','Queueing','Walking','Talking'] ACTIONS_ID={a:i for i,a in enumerate(ACTIONS)} ACTIVITIES_ID={a:i for i,a in enumerate(ACTIVITIES)} def collective_read_annotations(path,sid): annotations={} path=path + '/seq%02d/annotations.txt' % sid with open(path,mode='r') as f: frame_id=None group_activity=None actions=[] bboxes=[] for l in f.readlines(): values=l[:-1].split(' ') if int(values[0])!=frame_id: if frame_id!=None and frame_id%10==1 and frame_id+9<=FRAMES_NUM[sid]: counter = Counter(actions).most_common(2) group_activity= counter[0][0]-1 if counter[0][0]!=0 else counter[1][0]-1 annotations[frame_id]={ 'frame_id':frame_id, 'group_activity':group_activity, 'actions':actions, 'bboxes':bboxes } frame_id=int(values[0]) group_activity=None actions=[] bboxes=[] actions.append(int(values[5])-1) x,y,w,h = (int(values[i]) for i in range(1,5)) H,W=FRAMES_SIZE[sid] bboxes.append( (y/H,x/W,(y+h)/H,(x+w)/W) ) if frame_id!=None and frame_id%10==1 and frame_id+9<=FRAMES_NUM[sid]: counter = Counter(actions).most_common(2) group_activity= counter[0][0]-1 if counter[0][0]!=0 else counter[1][0]-1 annotations[frame_id]={ 'frame_id':frame_id, 'group_activity':group_activity, 'actions':actions, 'bboxes':bboxes } return annotations def collective_read_dataset(path,seqs): data = {} for sid in seqs: data[sid] = collective_read_annotations(path,sid) return data def collective_all_frames(anns): return [(s,f) for s in anns for f in anns[s] ] class CollectiveDataset(data.Dataset): """ Characterize collective dataset for pytorch """ def __init__(self,anns,frames,images_path,image_size,feature_size,num_boxes=13,num_frames=10,is_training=True,is_finetune=False): self.anns=anns self.frames=frames self.images_path=images_path self.image_size=image_size self.feature_size=feature_size self.num_boxes=num_boxes self.num_frames=num_frames self.is_training=is_training self.is_finetune=is_finetune def __len__(self): """ Return the total number of samples """ return len(self.frames) def __getitem__(self,index): """ Generate 
one sample of the dataset """ select_frames=self.get_frames(self.frames[index]) sample=self.load_samples_sequence(select_frames) return sample def get_frames(self,frame): sid, src_fid = frame if self.is_finetune: if self.is_training: fid=random.randint(src_fid, src_fid+self.num_frames-1) return [(sid, src_fid, fid)] else: return [(sid, src_fid, fid) for fid in range(src_fid, src_fid+self.num_frames)] else: if self.is_training: sample_frames=random.sample(range(src_fid,src_fid+self.num_frames),3) return [(sid, src_fid, fid) for fid in sample_frames] else: sample_frames=[ src_fid, src_fid+3, src_fid+6, src_fid+1, src_fid+4, src_fid+7, src_fid+2, src_fid+5, src_fid+8 ] return [(sid, src_fid, fid) for fid in sample_frames] def load_samples_sequence(self,select_frames): """ load samples sequence Returns: pytorch tensors """ OH, OW=self.feature_size images, bboxes = [], [] activities, actions = [], [] bboxes_num=[] for i, (sid, src_fid, fid) in enumerate(select_frames): img = Image.open(self.images_path + '/seq%02d/frame%04d.jpg'%(sid,fid)) img=transforms.functional.resize(img,self.image_size) img=np.array(img) # H,W,3 -> 3,H,W img=img.transpose(2,0,1) images.append(img) temp_boxes=[] for box in self.anns[sid][src_fid]['bboxes']: y1,x1,y2,x2=box w1,h1,w2,h2 = x1*OW, y1*OH, x2*OW, y2*OH temp_boxes.append((w1,h1,w2,h2)) temp_actions=self.anns[sid][src_fid]['actions'][:] bboxes_num.append(len(temp_boxes)) while len(temp_boxes)!=self.num_boxes: temp_boxes.append((0,0,0,0)) temp_actions.append(-1) bboxes.append(temp_boxes) actions.append(temp_actions) activities.append(self.anns[sid][src_fid]['group_activity']) images = np.stack(images) activities = np.array(activities, dtype=np.int32) bboxes_num = np.array(bboxes_num, dtype=np.int32) bboxes=np.array(bboxes,dtype=float).reshape(-1,self.num_boxes,4) actions=np.array(actions,dtype=np.int32).reshape(-1,self.num_boxes) #convert to pytorch tensor images=torch.from_numpy(images).float() bboxes=torch.from_numpy(bboxes).float() actions=torch.from_numpy(actions).long() activities=torch.from_numpy(activities).long() bboxes_num=torch.from_numpy(bboxes_num).int() return images, bboxes, actions, activities, bboxes_num ================================================ FILE: config.py ================================================ import time import os class Config(object): """ class to save config parameter """ def __init__(self, dataset_name): # Global self.image_size = 720, 1280 #input image size self.batch_size = 32 #train batch size self.test_batch_size = 8 #test batch size self.num_boxes = 12 #max number of bounding boxes in each frame # Gpu self.use_gpu=True self.use_multi_gpu=True self.device_list="0,1,2,3" #id list of gpus used for training # Dataset assert(dataset_name in ['volleyball', 'collective']) self.dataset_name=dataset_name if dataset_name=='volleyball': self.data_path='data/volleyball' #data path for the volleyball dataset self.train_seqs = [ 1,3,6,7,10,13,15,16,18,22,23,31,32,36,38,39,40,41,42,48,50,52,53,54, 0,2,8,12,17,19,24,26,27,28,30,33,46,49,51] #video id list of train set self.test_seqs = [4,5,9,11,14,20,21,25,29,34,35,37,43,44,45,47] #video id list of test set else: self.data_path='data/collective' #data path for the collective dataset self.test_seqs=[5,6,7,8,9,10,11,15,16,25,28,29] self.train_seqs=[s for s in range(1,45) if s not in self.test_seqs] # Backbone self.backbone='inv3' self.crop_size = 5, 5 #crop size of roi align self.train_backbone = False #whether to train the backbone: True in stage 1 (fine-tune it), False in
stage 2 self.out_size = 87, 157 #output feature map size of backbone self.emb_features=1056 #output feature map channel of backbone # Activity Action self.num_actions = 9 #number of action categories self.num_activities = 8 #number of activity categories self.actions_loss_weight = 1.0 #weight used to balance action loss and activity loss self.actions_weights = None # Sample self.num_frames = 3 self.num_before = 5 self.num_after = 4 # GCN self.num_features_boxes = 1024 self.num_features_relation=256 self.num_graph=16 #number of graphs self.num_features_gcn=self.num_features_boxes self.gcn_layers=1 #number of GCN layers self.tau_sqrt=False self.pos_threshold=0.2 #distance mask threshold in position relation # Training Parameters self.train_random_seed = 0 self.train_learning_rate = 2e-4 #initial learning rate self.lr_plan = {41:1e-4, 81:5e-5, 121:1e-5} #change learning rate in these epochs self.train_dropout_prob = 0.3 #dropout probability self.weight_decay = 0 #l2 weight decay self.max_epoch=150 #max training epoch self.test_interval_epoch=2 # Exp self.training_stage=1 #specify stage1 or stage2 self.stage1_model_path='' #path of the base model, need to be set in stage2 self.test_before_train=False self.exp_note='Group-Activity-Recognition' self.exp_name=None def init_config(self, need_new_folder=True): if self.exp_name is None: time_str=time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) self.exp_name='[%s_stage%d]<%s>'%(self.exp_note,self.training_stage,time_str) self.result_path='result/%s'%self.exp_name self.log_path='result/%s/log.txt'%self.exp_name if need_new_folder: os.mkdir(self.result_path) ================================================ FILE: data/collective/tracks/readTracks.m ================================================ function tracks = readTracks(filename) fp = fopen(filename, 'r'); tline = fgetl(fp); nframe = sscanf(tline, 'Total frames %d'); tline = fgetl(fp); ntargets = sscanf(tline, 'Number of Targets %d'); for n = 1:ntargets track = struct('id', n, 'ti', 0, 'te', 0, 'bbs', [], 'locs', []); tline = fgetl(fp); temp = sscanf(tline, 'Target %d (frames from %d to %d)'); track.id = temp(1); track.ti = temp(2); track.te = temp(3); len = temp(3) - temp(2) + 1; tline = fgetl(fp); % dummy line for t = 1:len tline = fgetl(fp); temp = sscanf(tline, '%d\t%d\t%d\t%d\t%d'); track.bbs(:, t) = temp(2:5); end tline = fgetl(fp); % dummy line for t = 1:len tline = fgetl(fp); temp = sscanf(tline, '%d\t%f\t%f\t%f\t%f'); track.locs(:, t) = temp(2:5); end tracks(n) = track; end fclose(fp); end ================================================ FILE: data/collective/tracks/showTracks.m ================================================ function showTracks(imdir, tracks) imfiles = dir([imdir '*.jpg']); for i = 1:length(imfiles) imshow([imdir imfiles(i).name]); drawTracks(tracks, i); drawnow; end end function drawTracks(tracks, frame) cmap = colormap; for i = 1:length(tracks) if ((tracks(i).ti <= frame) & ... 
(tracks(i).te >= frame)) idx = frame - tracks(i).ti + 1; col = cmap(mod(i*10, 64) + 1, :); rectangle('Position', tracks(i).bbs(:, idx), 'EdgeColor', col, 'LineWidth', 3); end end end ================================================ FILE: data/volleyball/tracks_normalized.pkl ================================================ [File too large to display: 23.9 MB] ================================================ FILE: dataset.py ================================================ from volleyball import * from collective import * import pickle def return_dataset(cfg): if cfg.dataset_name=='volleyball': train_anns = volley_read_dataset(cfg.data_path, cfg.train_seqs) train_frames = volley_all_frames(train_anns) test_anns = volley_read_dataset(cfg.data_path, cfg.test_seqs) test_frames = volley_all_frames(test_anns) all_anns = {**train_anns, **test_anns} all_tracks = pickle.load(open(cfg.data_path + '/tracks_normalized.pkl', 'rb')) training_set=VolleyballDataset(all_anns,all_tracks,train_frames, cfg.data_path,cfg.image_size,cfg.out_size,num_before=cfg.num_before, num_after=cfg.num_after,is_training=True,is_finetune=(cfg.training_stage==1)) validation_set=VolleyballDataset(all_anns,all_tracks,test_frames, cfg.data_path,cfg.image_size,cfg.out_size,num_before=cfg.num_before, num_after=cfg.num_after,is_training=False,is_finetune=(cfg.training_stage==1)) elif cfg.dataset_name=='collective': train_anns=collective_read_dataset(cfg.data_path, cfg.train_seqs) train_frames=collective_all_frames(train_anns) test_anns=collective_read_dataset(cfg.data_path, cfg.test_seqs) test_frames=collective_all_frames(test_anns) training_set=CollectiveDataset(train_anns,train_frames, cfg.data_path,cfg.image_size,cfg.out_size, num_frames=cfg.num_frames,is_training=True,is_finetune=(cfg.training_stage==1)) validation_set=CollectiveDataset(test_anns,test_frames, cfg.data_path,cfg.image_size,cfg.out_size, num_frames=cfg.num_frames,is_training=False,is_finetune=(cfg.training_stage==1)) else: assert False print('Reading dataset finished...') print('%d train samples'%len(train_frames)) print('%d test samples'%len(test_frames)) return training_set, validation_set ================================================ FILE: gcn_model.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from backbone import * from utils import * from roi_align.roi_align import RoIAlign # RoIAlign module from roi_align.roi_align import CropAndResize # crop_and_resize module class GCN_Module(nn.Module): def __init__(self, cfg): super(GCN_Module, self).__init__() self.cfg=cfg NFR =cfg.num_features_relation NG=cfg.num_graph N=cfg.num_boxes T=cfg.num_frames NFG=cfg.num_features_gcn NFG_ONE=NFG self.fc_rn_theta_list=torch.nn.ModuleList([ nn.Linear(NFG,NFR) for i in range(NG) ]) self.fc_rn_phi_list=torch.nn.ModuleList([ nn.Linear(NFG,NFR) for i in range(NG) ]) self.fc_gcn_list=torch.nn.ModuleList([ nn.Linear(NFG,NFG_ONE,bias=False) for i in range(NG) ]) if cfg.dataset_name=='volleyball': self.nl_gcn_list=torch.nn.ModuleList([ nn.LayerNorm([T*N,NFG_ONE]) for i in range(NG) ]) else: self.nl_gcn_list=torch.nn.ModuleList([ nn.LayerNorm([NFG_ONE]) for i in range(NG) ]) def forward(self,graph_boxes_features,boxes_in_flat): """ graph_boxes_features [B*T,N,NFG] """ # GCN graph modeling # Prepare boxes similarity relation B,N,NFG=graph_boxes_features.shape NFR=self.cfg.num_features_relation NG=self.cfg.num_graph NFG_ONE=NFG OH, OW=self.cfg.out_size pos_threshold=self.cfg.pos_threshold # 
Prepare position mask graph_boxes_positions=boxes_in_flat.clone() #B*T*N, 4 (clone: centers are computed in place below, so avoid aliasing the caller's boxes) graph_boxes_positions[:,0]=(graph_boxes_positions[:,0] + graph_boxes_positions[:,2]) / 2 graph_boxes_positions[:,1]=(graph_boxes_positions[:,1] + graph_boxes_positions[:,3]) / 2 graph_boxes_positions=graph_boxes_positions[:,:2].reshape(B,N,2) #B*T, N, 2 graph_boxes_distances=calc_pairwise_distance_3d(graph_boxes_positions,graph_boxes_positions) #B, N, N position_mask=( graph_boxes_distances > (pos_threshold*OW) ) relation_graph=None graph_boxes_features_list=[] for i in range(NG): graph_boxes_features_theta=self.fc_rn_theta_list[i](graph_boxes_features) #B,N,NFR graph_boxes_features_phi=self.fc_rn_phi_list[i](graph_boxes_features) #B,N,NFR # graph_boxes_features_theta=self.nl_rn_theta_list[i](graph_boxes_features_theta) # graph_boxes_features_phi=self.nl_rn_phi_list[i](graph_boxes_features_phi) similarity_relation_graph=torch.matmul(graph_boxes_features_theta,graph_boxes_features_phi.transpose(1,2)) #B,N,N similarity_relation_graph=similarity_relation_graph/np.sqrt(NFR) similarity_relation_graph=similarity_relation_graph.reshape(-1,1) #B*N*N, 1 # Build relation graph relation_graph=similarity_relation_graph relation_graph = relation_graph.reshape(B,N,N) relation_graph[position_mask]=-float('inf') relation_graph = torch.softmax(relation_graph,dim=2) # Graph convolution one_graph_boxes_features=self.fc_gcn_list[i]( torch.matmul(relation_graph,graph_boxes_features) ) #B, N, NFG_ONE one_graph_boxes_features=self.nl_gcn_list[i](one_graph_boxes_features) one_graph_boxes_features=F.relu(one_graph_boxes_features) graph_boxes_features_list.append(one_graph_boxes_features) graph_boxes_features=torch.sum(torch.stack(graph_boxes_features_list),dim=0) #B, N, NFG return graph_boxes_features,relation_graph class GCNnet_volleyball(nn.Module): """ main module of GCN for the volleyball dataset """ def __init__(self, cfg): super(GCNnet_volleyball, self).__init__() self.cfg=cfg T, N=self.cfg.num_frames, self.cfg.num_boxes D=self.cfg.emb_features K=self.cfg.crop_size[0] NFB=self.cfg.num_features_boxes NFR, NFG=self.cfg.num_features_relation, self.cfg.num_features_gcn NG=self.cfg.num_graph if cfg.backbone=='inv3': self.backbone=MyInception_v3(transform_input=False,pretrained=True) elif cfg.backbone=='vgg16': self.backbone=MyVGG16(pretrained=True) elif cfg.backbone=='vgg19': self.backbone=MyVGG19(pretrained=False) else: assert False if not cfg.train_backbone: for p in self.backbone.parameters(): p.requires_grad=False self.roi_align=RoIAlign(*self.cfg.crop_size) self.fc_emb_1=nn.Linear(K*K*D,NFB) self.nl_emb_1=nn.LayerNorm([NFB]) self.gcn_list = torch.nn.ModuleList([ GCN_Module(self.cfg) for i in range(self.cfg.gcn_layers) ]) self.dropout_global=nn.Dropout(p=self.cfg.train_dropout_prob) self.fc_actions=nn.Linear(NFG,self.cfg.num_actions) self.fc_activities=nn.Linear(NFG,self.cfg.num_activities) for m in self.modules(): if isinstance(m,nn.Linear): nn.init.kaiming_normal_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) def loadmodel(self,filepath): state = torch.load(filepath) self.backbone.load_state_dict(state['backbone_state_dict']) self.fc_emb_1.load_state_dict(state['fc_emb_state_dict']) print('Load model states from: ',filepath) def forward(self,batch_data): images_in, boxes_in = batch_data # read config parameters B=images_in.shape[0] T=images_in.shape[1] H, W=self.cfg.image_size OH, OW=self.cfg.out_size N=self.cfg.num_boxes NFB=self.cfg.num_features_boxes NFR, NFG=self.cfg.num_features_relation, self.cfg.num_features_gcn
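# shorthand: B batch clips, T frames, N boxes per frame; NFB/NFR/NFG are the box-, relation-, and graph-feature widths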
NG=self.cfg.num_graph D=self.cfg.emb_features K=self.cfg.crop_size[0] if not self.training: B=B*3 T=T//3 images_in=images_in.reshape( (B,T)+images_in.shape[2:] ) boxes_in=boxes_in.reshape( (B,T)+boxes_in.shape[2:] ) # Reshape the input data images_in_flat=torch.reshape(images_in,(B*T,3,H,W)) #B*T, 3, H, W boxes_in_flat=torch.reshape(boxes_in,(B*T*N,4)) #B*T*N, 4 boxes_idx=[i * torch.ones(N, dtype=torch.int) for i in range(B*T) ] boxes_idx=torch.stack(boxes_idx).to(device=boxes_in.device) # B*T, N boxes_idx_flat=torch.reshape(boxes_idx,(B*T*N,)) #B*T*N, # Use backbone to extract features of images_in # Pre-process first images_in_flat=prep_images(images_in_flat) outputs=self.backbone(images_in_flat) # Build features assert outputs[0].shape[2:4]==torch.Size([OH,OW]) features_multiscale=[] for features in outputs: if features.shape[2:4]!=torch.Size([OH,OW]): features=F.interpolate(features,size=(OH,OW),mode='bilinear',align_corners=True) features_multiscale.append(features) features_multiscale=torch.cat(features_multiscale,dim=1) #B*T, D, OH, OW # RoI Align boxes_in_flat.requires_grad=False boxes_idx_flat.requires_grad=False boxes_features=self.roi_align(features_multiscale, boxes_in_flat, boxes_idx_flat) #B*T*N, D, K, K, boxes_features=boxes_features.reshape(B,T,N,-1) #B,T,N, D*K*K # Embedding boxes_features=self.fc_emb_1(boxes_features) # B,T,N, NFB boxes_features=self.nl_emb_1(boxes_features) boxes_features=F.relu(boxes_features) # GCN graph_boxes_features=boxes_features.reshape(B,T*N,NFG) # visual_info=[] for i in range(len(self.gcn_list)): graph_boxes_features,relation_graph=self.gcn_list[i](graph_boxes_features,boxes_in_flat) # visual_info.append(relation_graph.reshape(B,T,N,N)) # fuse graph_boxes_features with boxes_features graph_boxes_features=graph_boxes_features.reshape(B,T,N,NFG) boxes_features=boxes_features.reshape(B,T,N,NFB) # boxes_states= torch.cat( [graph_boxes_features,boxes_features],dim=3) #B, T, N, NFG+NFB boxes_states=graph_boxes_features+boxes_features boxes_states=self.dropout_global(boxes_states) NFS=NFG # Predict actions boxes_states_flat=boxes_states.reshape(-1,NFS) #B*T*N, NFS actions_scores=self.fc_actions(boxes_states_flat) #B*T*N, actn_num # Predict activities boxes_states_pooled,_=torch.max(boxes_states,dim=2) boxes_states_pooled_flat=boxes_states_pooled.reshape(-1,NFS) activities_scores=self.fc_activities(boxes_states_pooled_flat) #B*T, acty_num # Temporal fusion actions_scores=actions_scores.reshape(B,T,N,-1) actions_scores=torch.mean(actions_scores,dim=1).reshape(B*N,-1) activities_scores=activities_scores.reshape(B,T,-1) activities_scores=torch.mean(activities_scores,dim=1).reshape(B,-1) if not self.training: B=B//3 actions_scores=torch.mean(actions_scores.reshape(B,3,N,-1),dim=1).reshape(B*N,-1) activities_scores=torch.mean(activities_scores.reshape(B,3,-1),dim=1).reshape(B,-1) return actions_scores, activities_scores class GCNnet_collective(nn.Module): """ main module of GCN for the collective dataset """ def __init__(self, cfg): super(GCNnet_collective, self).__init__() self.cfg=cfg D=self.cfg.emb_features K=self.cfg.crop_size[0] NFB=self.cfg.num_features_boxes NFR, NFG=self.cfg.num_features_relation, self.cfg.num_features_gcn self.backbone=MyInception_v3(transform_input=False,pretrained=True) if not self.cfg.train_backbone: for p in self.backbone.parameters(): p.requires_grad=False self.roi_align=RoIAlign(*self.cfg.crop_size) self.fc_emb_1=nn.Linear(K*K*D,NFB) self.nl_emb_1=nn.LayerNorm([NFB]) self.gcn_list = torch.nn.ModuleList([ GCN_Module(self.cfg) for i in
range(self.cfg.gcn_layers) ]) self.dropout_global=nn.Dropout(p=self.cfg.train_dropout_prob) self.fc_actions=nn.Linear(NFG,self.cfg.num_actions) self.fc_activities=nn.Linear(NFG,self.cfg.num_activities) for m in self.modules(): if isinstance(m,nn.Linear): nn.init.kaiming_normal_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) # nn.init.zeros_(self.fc_gcn_3.weight) def loadmodel(self,filepath): state = torch.load(filepath) self.backbone.load_state_dict(state['backbone_state_dict']) self.fc_emb_1.load_state_dict(state['fc_emb_state_dict']) print('Load model states from: ',filepath) def forward(self,batch_data): images_in, boxes_in, bboxes_num_in = batch_data # read config parameters B=images_in.shape[0] T=images_in.shape[1] H, W=self.cfg.image_size OH, OW=self.cfg.out_size MAX_N=self.cfg.num_boxes NFB=self.cfg.num_features_boxes NFR, NFG=self.cfg.num_features_relation, self.cfg.num_features_gcn D=self.cfg.emb_features K=self.cfg.crop_size[0] if not self.training: B=B*3 T=T//3 images_in=images_in.reshape( (B,T)+images_in.shape[2:] ) boxes_in=boxes_in.reshape( (B,T)+boxes_in.shape[2:] ) bboxes_num_in=bboxes_num_in.reshape((B,T)) # Reshape the input data images_in_flat=torch.reshape(images_in,(B*T,3,H,W)) #B*T, 3, H, W boxes_in=boxes_in.reshape(B*T,MAX_N,4) # Use backbone to extract features of images_in # Pre-process first images_in_flat=prep_images(images_in_flat) outputs=self.backbone(images_in_flat) # Build multiscale features features_multiscale=[] for features in outputs: if features.shape[2:4]!=torch.Size([OH,OW]): features=F.interpolate(features,size=(OH,OW),mode='bilinear',align_corners=True) features_multiscale.append(features) features_multiscale=torch.cat(features_multiscale,dim=1) #B*T, D, OH, OW boxes_in_flat=torch.reshape(boxes_in,(B*T*MAX_N,4)) #B*T*MAX_N, 4 boxes_idx=[i * torch.ones(MAX_N, dtype=torch.int) for i in range(B*T) ] boxes_idx=torch.stack(boxes_idx).to(device=boxes_in.device) # B*T, MAX_N boxes_idx_flat=torch.reshape(boxes_idx,(B*T*MAX_N,)) #B*T*MAX_N, # RoI Align boxes_in_flat.requires_grad=False boxes_idx_flat.requires_grad=False boxes_features_all=self.roi_align(features_multiscale, boxes_in_flat, boxes_idx_flat) #B*T*MAX_N, D, K, K, boxes_features_all=boxes_features_all.reshape(B*T,MAX_N,-1) #B*T,MAX_N, D*K*K # Embedding boxes_features_all=self.fc_emb_1(boxes_features_all) # B*T,MAX_N, NFB boxes_features_all=self.nl_emb_1(boxes_features_all) boxes_features_all=F.relu(boxes_features_all) boxes_features_all=boxes_features_all.reshape(B,T,MAX_N,NFB) boxes_in=boxes_in.reshape(B,T,MAX_N,4) actions_scores=[] activities_scores=[] bboxes_num_in=bboxes_num_in.reshape(B,T) #B,T, for b in range(B): N=bboxes_num_in[b][0] boxes_features=boxes_features_all[b,:,:N,:].reshape(1,T*N,NFB) #1,T,N,NFB boxes_positions=boxes_in[b,:,:N,:].reshape(T*N,4) #T*N, 4 # GCN graph modeling for i in range(len(self.gcn_list)): graph_boxes_features,relation_graph=self.gcn_list[i](boxes_features,boxes_positions) # cat graph_boxes_features with boxes_features boxes_features=boxes_features.reshape(1,T*N,NFB) boxes_states=graph_boxes_features+boxes_features #1, T*N, NFG boxes_states=self.dropout_global(boxes_states) NFS=NFG boxes_states=boxes_states.reshape(T,N,NFS) # Predict actions actn_score=self.fc_actions(boxes_states) #T,N, actn_num # Predict activities boxes_states_pooled,_=torch.max(boxes_states,dim=1) #T, NFS acty_score=self.fc_activities(boxes_states_pooled) #T, acty_num # GSN fusion actn_score=torch.mean(actn_score,dim=0).reshape(N,-1) #N, actn_num acty_score=torch.mean(acty_score,dim=0).reshape(1,-1) #1, acty_num
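# N varies from clip to clip, so per-clip scores are collected in Python lists and concatenated after the loop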
actions_scores.append(actn_score) activities_scores.append(acty_score) actions_scores=torch.cat(actions_scores,dim=0) #ALL_N,actn_num activities_scores=torch.cat(activities_scores,dim=0) #B,acty_num if not self.training: B=B//3 actions_scores=torch.mean(actions_scores.reshape(-1,3,actions_scores.shape[1]),dim=1) activities_scores=torch.mean(activities_scores.reshape(B,3,-1),dim=1).reshape(B,-1) # print(actions_scores.shape) # print(activities_scores.shape) return actions_scores, activities_scores ================================================ FILE: result/.gitkeep ================================================ ================================================ FILE: scripts/train_collective_stage1.py ================================================ import sys sys.path.append(".") from train_net import * cfg=Config('collective') cfg.device_list="0,1" cfg.training_stage=1 cfg.train_backbone=True cfg.image_size=480, 720 cfg.out_size=57,87 cfg.num_boxes=13 cfg.num_actions=6 cfg.num_activities=5 cfg.num_frames=10 cfg.batch_size=16 cfg.test_batch_size=8 cfg.train_learning_rate=1e-5 cfg.train_dropout_prob=0.5 cfg.weight_decay=1e-2 cfg.lr_plan={} cfg.max_epoch=100 cfg.exp_note='Collective_stage1' train_net(cfg) ================================================ FILE: scripts/train_collective_stage2.py ================================================ import sys sys.path.append(".") from train_net import * cfg=Config('collective') cfg.device_list="0,1" cfg.training_stage=2 cfg.stage1_model_path='result/STAGE1_MODEL.pth' #PATH OF THE BASE MODEL cfg.train_backbone=False cfg.image_size=480, 720 cfg.out_size=57,87 cfg.num_boxes=13 cfg.num_actions=6 cfg.num_activities=5 cfg.num_frames=10 cfg.num_graph=4 cfg.tau_sqrt=True cfg.batch_size=16 cfg.test_batch_size=8 cfg.train_learning_rate=1e-4 cfg.train_dropout_prob=0.2 cfg.weight_decay=1e-2 cfg.lr_plan={} cfg.max_epoch=50 cfg.exp_note='Collective_stage2' train_net(cfg) ================================================ FILE: scripts/train_volleyball_stage1.py ================================================ import sys sys.path.append(".") from train_net import * cfg=Config('volleyball') cfg.device_list="0,1,2,3" cfg.training_stage=1 cfg.stage1_model_path='' cfg.train_backbone=True cfg.batch_size=8 cfg.test_batch_size=4 cfg.num_frames=1 cfg.train_learning_rate=1e-5 cfg.lr_plan={} cfg.max_epoch=200 cfg.actions_weights=[[1., 1., 2., 3., 1., 2., 2., 0.2, 1.]] cfg.exp_note='Volleyball_stage1' train_net(cfg) ================================================ FILE: scripts/train_volleyball_stage2.py ================================================ import sys sys.path.append(".") from train_net import * cfg=Config('volleyball') cfg.device_list="0,1,2,3" cfg.training_stage=2 cfg.stage1_model_path='result/STAGE1_MODEL.pth' #PATH OF THE BASE MODEL cfg.train_backbone=False cfg.batch_size=32 #32 cfg.test_batch_size=8 cfg.num_frames=3 cfg.train_learning_rate=2e-4 cfg.lr_plan={41:1e-4, 81:5e-5, 121:1e-5} cfg.max_epoch=150 cfg.actions_weights=[[1., 1., 2., 3., 1., 2., 2., 0.2, 1.]] cfg.exp_note='Volleyball_stage2' train_net(cfg) ================================================ FILE: train_net.py ================================================ import torch import torch.optim as optim import time import random import os import sys from config import * from volleyball import * from collective import * from dataset import * from gcn_model import * from base_model import * from utils import * def set_bn_eval(m): classname = m.__class__.__name__ if classname.find('BatchNorm') != -1: 
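# keep BatchNorm layers in eval mode so their pretrained running statistics are not updated while fine-tuning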
m.eval() def adjust_lr(optimizer, new_lr): print('change learning rate:',new_lr) for param_group in optimizer.param_groups: param_group['lr'] = new_lr def train_net(cfg): """ training gcn net """ os.environ['CUDA_VISIBLE_DEVICES']=cfg.device_list # Show config parameters cfg.init_config() show_config(cfg) # Reading dataset training_set,validation_set=return_dataset(cfg) params = { 'batch_size': cfg.batch_size, 'shuffle': True, 'num_workers': 4 } training_loader=data.DataLoader(training_set,**params) params['batch_size']=cfg.test_batch_size validation_loader=data.DataLoader(validation_set,**params) # Set random seed np.random.seed(cfg.train_random_seed) torch.manual_seed(cfg.train_random_seed) random.seed(cfg.train_random_seed) # Set data position if cfg.use_gpu and torch.cuda.is_available(): device = torch.device('cuda') else: device = torch.device('cpu') # Build model and optimizer basenet_list={'volleyball':Basenet_volleyball, 'collective':Basenet_collective} gcnnet_list={'volleyball':GCNnet_volleyball, 'collective':GCNnet_collective} if cfg.training_stage==1: Basenet=basenet_list[cfg.dataset_name] model=Basenet(cfg) elif cfg.training_stage==2: GCNnet=gcnnet_list[cfg.dataset_name] model=GCNnet(cfg) # Load backbone model.loadmodel(cfg.stage1_model_path) else: assert(False) if cfg.use_multi_gpu: model=nn.DataParallel(model) model=model.to(device=device) model.train() model.apply(set_bn_eval) optimizer=optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),lr=cfg.train_learning_rate,weight_decay=cfg.weight_decay) train_list={'volleyball':train_volleyball, 'collective':train_collective} test_list={'volleyball':test_volleyball, 'collective':test_collective} train=train_list[cfg.dataset_name] test=test_list[cfg.dataset_name] if cfg.test_before_train: test_info=test(validation_loader, model, device, 0, cfg) print(test_info) # Training iteration best_result={'epoch':0, 'activities_acc':0} start_epoch=1 for epoch in range(start_epoch, start_epoch+cfg.max_epoch): if epoch in cfg.lr_plan: adjust_lr(optimizer, cfg.lr_plan[epoch]) # One epoch of forward and backward train_info=train(training_loader, model, device, optimizer, epoch, cfg) show_epoch_info('Train', cfg.log_path, train_info) # Test if epoch % cfg.test_interval_epoch == 0: test_info=test(validation_loader, model, device, epoch, cfg) show_epoch_info('Test', cfg.log_path, test_info) if test_info['activities_acc']>best_result['activities_acc']: best_result=test_info print_log(cfg.log_path, 'Best group activity accuracy: %.2f%% at epoch #%d.'%(best_result['activities_acc'], best_result['epoch'])) # Save model if cfg.training_stage==2: state = { 'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), } filepath=cfg.result_path+'/stage%d_epoch%d_%.2f%%.pth'%(cfg.training_stage,epoch,test_info['activities_acc']) torch.save(state, filepath) print('model saved to:',filepath) elif cfg.training_stage==1: for m in model.modules(): if isinstance(m, Basenet): filepath=cfg.result_path+'/stage%d_epoch%d_%.2f%%.pth'%(cfg.training_stage,epoch,test_info['activities_acc']) m.savemodel(filepath) # print('model saved to:',filepath) else: assert False def train_volleyball(data_loader, model, device, optimizer, epoch, cfg): actions_meter=AverageMeter() activities_meter=AverageMeter() loss_meter=AverageMeter() epoch_timer=Timer() for batch_data in data_loader: model.train() model.apply(set_bn_eval) # prepare batch data batch_data=[b.to(device=device) for b in batch_data] batch_size=batch_data[0].shape[0] 
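# batch layout from VolleyballDataset: images (B,T,3,H,W), boxes (B,T,N,4), actions (B,T,N), activities (B,T)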
num_frames=batch_data[0].shape[1] actions_in=batch_data[2].reshape((batch_size,num_frames,cfg.num_boxes)) activities_in=batch_data[3].reshape((batch_size,num_frames)) actions_in=actions_in[:,0,:].reshape((batch_size*cfg.num_boxes,)) activities_in=activities_in[:,0].reshape((batch_size,)) # forward actions_scores,activities_scores=model((batch_data[0],batch_data[1])) # Predict actions actions_weights=torch.tensor(cfg.actions_weights).to(device=device) actions_loss=F.cross_entropy(actions_scores,actions_in,weight=actions_weights) actions_labels=torch.argmax(actions_scores,dim=1) actions_correct=torch.sum(torch.eq(actions_labels.int(),actions_in.int()).float()) # Predict activities activities_loss=F.cross_entropy(activities_scores,activities_in) activities_labels=torch.argmax(activities_scores,dim=1) activities_correct=torch.sum(torch.eq(activities_labels.int(),activities_in.int()).float()) # Get accuracy actions_accuracy=actions_correct.item()/actions_scores.shape[0] activities_accuracy=activities_correct.item()/activities_scores.shape[0] actions_meter.update(actions_accuracy, actions_scores.shape[0]) activities_meter.update(activities_accuracy, activities_scores.shape[0]) # Total loss total_loss=activities_loss+cfg.actions_loss_weight*actions_loss loss_meter.update(total_loss.item(), batch_size) # Optim optimizer.zero_grad() total_loss.backward() optimizer.step() train_info={ 'time':epoch_timer.timeit(), 'epoch':epoch, 'loss':loss_meter.avg, 'activities_acc':activities_meter.avg*100, 'actions_acc':actions_meter.avg*100 } return train_info def test_volleyball(data_loader, model, device, epoch, cfg): model.eval() actions_meter=AverageMeter() activities_meter=AverageMeter() loss_meter=AverageMeter() epoch_timer=Timer() with torch.no_grad(): for batch_data_test in data_loader: # prepare batch data batch_data_test=[b.to(device=device) for b in batch_data_test] batch_size=batch_data_test[0].shape[0] num_frames=batch_data_test[0].shape[1] actions_in=batch_data_test[2].reshape((batch_size,num_frames,cfg.num_boxes)) activities_in=batch_data_test[3].reshape((batch_size,num_frames)) # forward actions_scores,activities_scores=model((batch_data_test[0],batch_data_test[1])) # Predict actions actions_in=actions_in[:,0,:].reshape((batch_size*cfg.num_boxes,)) activities_in=activities_in[:,0].reshape((batch_size,)) actions_weights=torch.tensor(cfg.actions_weights).to(device=device) actions_loss=F.cross_entropy(actions_scores,actions_in,weight=actions_weights) actions_labels=torch.argmax(actions_scores,dim=1) # Predict activities activities_loss=F.cross_entropy(activities_scores,activities_in) activities_labels=torch.argmax(activities_scores,dim=1) actions_correct=torch.sum(torch.eq(actions_labels.int(),actions_in.int()).float()) activities_correct=torch.sum(torch.eq(activities_labels.int(),activities_in.int()).float()) # Get accuracy actions_accuracy=actions_correct.item()/actions_scores.shape[0] activities_accuracy=activities_correct.item()/activities_scores.shape[0] actions_meter.update(actions_accuracy, actions_scores.shape[0]) activities_meter.update(activities_accuracy, activities_scores.shape[0]) # Total loss total_loss=activities_loss+cfg.actions_loss_weight*actions_loss loss_meter.update(total_loss.item(), batch_size) test_info={ 'time':epoch_timer.timeit(), 'epoch':epoch, 'loss':loss_meter.avg, 'activities_acc':activities_meter.avg*100, 'actions_acc':actions_meter.avg*100 } return test_info def train_collective(data_loader, model, device, optimizer, epoch, cfg): actions_meter=AverageMeter() 
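# running per-epoch averages of action accuracy, activity accuracy, and loss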
activities_meter=AverageMeter() loss_meter=AverageMeter() epoch_timer=Timer() for batch_data in data_loader: model.train() model.apply(set_bn_eval) # prepare batch data batch_data=[b.to(device=device) for b in batch_data] batch_size=batch_data[0].shape[0] num_frames=batch_data[0].shape[1] # forward actions_scores,activities_scores=model((batch_data[0],batch_data[1],batch_data[4])) actions_in=batch_data[2].reshape((batch_size,num_frames,cfg.num_boxes)) activities_in=batch_data[3].reshape((batch_size,num_frames)) bboxes_num=batch_data[4].reshape(batch_size,num_frames) actions_in_nopad=[] if cfg.training_stage==1: actions_in=actions_in.reshape((batch_size*num_frames,cfg.num_boxes,)) bboxes_num=bboxes_num.reshape(batch_size*num_frames,) for bt in range(batch_size*num_frames): N=bboxes_num[bt] actions_in_nopad.append(actions_in[bt,:N]) else: for b in range(batch_size): N=bboxes_num[b][0] actions_in_nopad.append(actions_in[b][0][:N]) actions_in=torch.cat(actions_in_nopad,dim=0).reshape(-1,) #ALL_N, if cfg.training_stage==1: activities_in=activities_in.reshape(-1,) else: activities_in=activities_in[:,0].reshape(batch_size,) # Predict actions actions_loss=F.cross_entropy(actions_scores,actions_in,weight=None) actions_labels=torch.argmax(actions_scores,dim=1) #B*T*N, actions_correct=torch.sum(torch.eq(actions_labels.int(),actions_in.int()).float()) # Predict activities activities_loss=F.cross_entropy(activities_scores,activities_in) activities_labels=torch.argmax(activities_scores,dim=1) #B*T, activities_correct=torch.sum(torch.eq(activities_labels.int(),activities_in.int()).float()) # Get accuracy actions_accuracy=actions_correct.item()/actions_scores.shape[0] activities_accuracy=activities_correct.item()/activities_scores.shape[0] actions_meter.update(actions_accuracy, actions_scores.shape[0]) activities_meter.update(activities_accuracy, activities_scores.shape[0]) # Total loss total_loss=activities_loss+cfg.actions_loss_weight*actions_loss loss_meter.update(total_loss.item(), batch_size) # Optim optimizer.zero_grad() total_loss.backward() optimizer.step() train_info={ 'time':epoch_timer.timeit(), 'epoch':epoch, 'loss':loss_meter.avg, 'activities_acc':activities_meter.avg*100, 'actions_acc':actions_meter.avg*100 } return train_info def test_collective(data_loader, model, device, epoch, cfg): model.eval() actions_meter=AverageMeter() activities_meter=AverageMeter() loss_meter=AverageMeter() epoch_timer=Timer() with torch.no_grad(): for batch_data in data_loader: # prepare batch data batch_data=[b.to(device=device) for b in batch_data] batch_size=batch_data[0].shape[0] num_frames=batch_data[0].shape[1] actions_in=batch_data[2].reshape((batch_size,num_frames,cfg.num_boxes)) activities_in=batch_data[3].reshape((batch_size,num_frames)) bboxes_num=batch_data[4].reshape(batch_size,num_frames) # forward actions_scores,activities_scores=model((batch_data[0],batch_data[1],batch_data[4])) actions_in_nopad=[] if cfg.training_stage==1: actions_in=actions_in.reshape((batch_size*num_frames,cfg.num_boxes,)) bboxes_num=bboxes_num.reshape(batch_size*num_frames,) for bt in range(batch_size*num_frames): N=bboxes_num[bt] actions_in_nopad.append(actions_in[bt,:N]) else: for b in range(batch_size): N=bboxes_num[b][0] actions_in_nopad.append(actions_in[b][0][:N]) actions_in=torch.cat(actions_in_nopad,dim=0).reshape(-1,) #ALL_N, if cfg.training_stage==1: activities_in=activities_in.reshape(-1,) else: activities_in=activities_in[:,0].reshape(batch_size,) actions_loss=F.cross_entropy(actions_scores,actions_in) 
actions_labels=torch.argmax(actions_scores,dim=1) #ALL_N, actions_correct=torch.sum(torch.eq(actions_labels.int(),actions_in.int()).float()) # Predict activities activities_loss=F.cross_entropy(activities_scores,activities_in) activities_labels=torch.argmax(activities_scores,dim=1) #B, activities_correct=torch.sum(torch.eq(activities_labels.int(),activities_in.int()).float()) # Get accuracy actions_accuracy=actions_correct.item()/actions_scores.shape[0] activities_accuracy=activities_correct.item()/activities_scores.shape[0] actions_meter.update(actions_accuracy, actions_scores.shape[0]) activities_meter.update(activities_accuracy, activities_scores.shape[0]) # Total loss total_loss=activities_loss+cfg.actions_loss_weight*actions_loss loss_meter.update(total_loss.item(), batch_size) test_info={ 'time':epoch_timer.timeit(), 'epoch':epoch, 'loss':loss_meter.avg, 'activities_acc':activities_meter.avg*100, 'actions_acc':actions_meter.avg*100 } return test_info ================================================ FILE: utils.py ================================================ import pickle import time import numpy as np import torch # numpy and pickle are used by sincos_encoding_2d and log_final_exp_result below def prep_images(images): """ preprocess images Args: images: pytorch tensor """ images = images.div(255.0) images = torch.sub(images,0.5) images = torch.mul(images,2.0) return images def calc_pairwise_distance(X, Y): """ computes pairwise distance between each element Args: X: [N,D] Y: [M,D] Returns: dist: [N,M] matrix of euclidean distances """ rx=X.pow(2).sum(dim=1).reshape((-1,1)) ry=Y.pow(2).sum(dim=1).reshape((-1,1)) dist=rx-2.0*X.matmul(Y.t())+ry.t() return torch.sqrt(dist) def calc_pairwise_distance_3d(X, Y): """ computes pairwise distance between each element Args: X: [B,N,D] Y: [B,M,D] Returns: dist: [B,N,M] matrix of euclidean distances """ B=X.shape[0] rx=X.pow(2).sum(dim=2).reshape((B,-1,1)) ry=Y.pow(2).sum(dim=2).reshape((B,-1,1)) dist=rx-2.0*X.matmul(Y.transpose(1,2))+ry.transpose(1,2) return torch.sqrt(dist) def sincos_encoding_2d(positions,d_emb): """ Args: positions: [N,2] Returns: positions high-dimensional representation: [N,d_emb] """ N=positions.shape[0] d=d_emb//2 idxs = [np.power(1000,2*(idx//2)/d) for idx in range(d)] idxs = torch.FloatTensor(idxs).to(device=positions.device) idxs = idxs.repeat(N,2) #N, d_emb pos = torch.cat([ positions[:,0].reshape(-1,1).repeat(1,d),positions[:,1].reshape(-1,1).repeat(1,d) ],dim=1) embeddings=pos/idxs embeddings[:,0::2]=torch.sin(embeddings[:,0::2]) # dim 2i embeddings[:,1::2]=torch.cos(embeddings[:,1::2]) # dim 2i+1 return embeddings def print_log(file_path,*args): print(*args) if file_path is not None: with open(file_path, 'a') as f: print(*args,file=f) def show_config(cfg): print_log(cfg.log_path, '=====================Config=====================') for k,v in cfg.__dict__.items(): print_log(cfg.log_path, k,': ',v) print_log(cfg.log_path, '======================End=======================') def show_epoch_info(phase, log_path, info): print_log(log_path, '') if phase=='Test': print_log(log_path, '====> %s at epoch #%d'%(phase, info['epoch'])) else: print_log(log_path, '%s at epoch #%d'%(phase, info['epoch'])) print_log(log_path, 'Group Activity Accuracy: %.2f%%, Loss: %.5f, Using %.1f seconds'%( info['activities_acc'], info['loss'], info['time'])) def log_final_exp_result(log_path, data_path, exp_result): no_display_cfg=['num_workers', 'use_gpu', 'use_multi_gpu', 'device_list', 'batch_size_test', 'test_interval_epoch', 'train_random_seed', 'result_path', 'log_path', 'device'] with open(log_path, 'a') as f: print('', file=f) print('', file=f)
print('', file=f) print('=====================Config=====================', file=f) for k,v in exp_result['cfg'].__dict__.items(): if k not in no_display_cfg: print( k,': ',v, file=f) print('=====================Result======================', file=f) print('Best result:', file=f) print(exp_result['best_result'], file=f) print('Cost total %.4f hours.'%(exp_result['total_time']), file=f) print('======================End=======================', file=f) data_dict=pickle.load(open(data_path, 'rb')) data_dict[exp_result['cfg'].exp_name]=exp_result pickle.dump(data_dict, open(data_path, 'wb')) class AverageMeter(object): """ Computes the average value """ def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count class Timer(object): """ class to do timekeeping """ def __init__(self): self.last_time=time.time() def timeit(self): old_time=self.last_time self.last_time=time.time() return self.last_time-old_time ================================================ FILE: volleyball.py ================================================ import numpy as np import skimage.io import skimage.transform import torch import torchvision.transforms as transforms from torch.utils import data import torchvision.models as models from PIL import Image import random import sys """ Reference: https://github.com/cvlab-epfl/social-scene-understanding/blob/master/volleyball.py """ ACTIVITIES = ['r_set', 'r_spike', 'r-pass', 'r_winpoint', 'l_set', 'l-spike', 'l-pass', 'l_winpoint'] NUM_ACTIVITIES = 8 ACTIONS = ['blocking', 'digging', 'falling', 'jumping', 'moving', 'setting', 'spiking', 'standing', 'waiting'] NUM_ACTIONS = 9 def volley_read_annotations(path): """ reading annotations for the given sequence """ annotations = {} gact_to_id = {name: i for i, name in enumerate(ACTIVITIES)} act_to_id = {name: i for i, name in enumerate(ACTIONS)} with open(path) as f: for l in f.readlines(): values = l[:-1].split(' ') file_name = values[0] activity = gact_to_id[values[1]] values = values[2:] num_people = len(values) // 5 action_names = values[4::5] actions = [act_to_id[name] for name in action_names] def _read_bbox(xywh): x, y, w, h = map(int, xywh) return y, x, y+h, x+w bboxes = np.array([_read_bbox(values[i:i+4]) for i in range(0, 5*num_people, 5)]) fid = int(file_name.split('.')[0]) annotations[fid] = { 'file_name': file_name, 'group_activity': activity, 'actions': actions, 'bboxes': bboxes, } return annotations def volley_read_dataset(path, seqs): data = {} for sid in seqs: data[sid] = volley_read_annotations(path + '/%d/annotations.txt' % sid) return data def volley_all_frames(data): frames = [] for sid, anns in data.items(): for fid, ann in anns.items(): frames.append((sid, fid)) return frames def volley_random_frames(data, num_frames): frames = [] for sid in np.random.choice(list(data.keys()), num_frames): fid = int(np.random.choice(list(data[sid]), [])) frames.append((sid, fid)) return frames def volley_frames_around(frame, num_before=5, num_after=4): sid, src_fid = frame return [(sid, src_fid, fid) for fid in range(src_fid-num_before, src_fid+num_after+1)] def load_samples_sequence(anns,tracks,images_path,frames,image_size,num_boxes=12,): """ load samples of a batch Returns: pytorch tensors """ images, boxes, boxes_idx = [], [], [] activities, actions = [], [] for i, (sid, src_fid, fid) in enumerate(frames): #img=skimage.io.imread(images_path + '/%d/%d/%d.jpg' % (sid,
src_fid, fid)) #img=skimage.transform.resize(img,(720, 1280),anti_aliasing=True) img = Image.open(images_path + '/%d/%d/%d.jpg' % (sid, src_fid, fid)) img=transforms.functional.resize(img,image_size) img=np.array(img) # H,W,3 -> 3,H,W img=img.transpose(2,0,1) images.append(img) boxes.append(tracks[(sid, src_fid)][fid]) actions.append(anns[sid][src_fid]['actions']) if len(boxes[-1]) != num_boxes: boxes[-1] = np.vstack([boxes[-1], boxes[-1][:num_boxes-len(boxes[-1])]]) actions[-1] = actions[-1] + actions[-1][:num_boxes-len(actions[-1])] boxes_idx.append(i * np.ones(num_boxes, dtype=np.int32)) activities.append(anns[sid][src_fid]['group_activity']) images = np.stack(images) activities = np.array(activities, dtype=np.int32) bboxes = np.vstack(boxes).reshape([-1, num_boxes, 4]) bboxes_idx = np.hstack(boxes_idx).reshape([-1, num_boxes]) actions = np.hstack(actions).reshape([-1, num_boxes]) #convert to pytorch tensor images=torch.from_numpy(images).float() bboxes=torch.from_numpy(bboxes).float() bboxes_idx=torch.from_numpy(bboxes_idx).int() actions=torch.from_numpy(actions).long() activities=torch.from_numpy(activities).long() return images, bboxes, bboxes_idx, actions, activities class VolleyballDataset(data.Dataset): """ Characterize volleyball dataset for pytorch """ def __init__(self,anns,tracks,frames,images_path,image_size,feature_size,num_boxes=12,num_before=4,num_after=4,is_training=True,is_finetune=False): self.anns=anns self.tracks=tracks self.frames=frames self.images_path=images_path self.image_size=image_size self.feature_size=feature_size self.num_boxes=num_boxes self.num_before=num_before self.num_after=num_after self.is_training=is_training self.is_finetune=is_finetune def __len__(self): """ Return the total number of samples """ return len(self.frames) def __getitem__(self,index): """ Generate one sample of the dataset """ select_frames=self.volley_frames_sample(self.frames[index]) sample=self.load_samples_sequence(select_frames) return sample def volley_frames_sample(self,frame): sid, src_fid = frame if self.is_finetune: if self.is_training: fid=random.randint(src_fid-self.num_before, src_fid+self.num_after) return [(sid, src_fid, fid)] else: return [(sid, src_fid, fid) for fid in range(src_fid-self.num_before, src_fid+self.num_after+1)] else: if self.is_training: sample_frames=random.sample(range(src_fid-self.num_before, src_fid+self.num_after+1), 3) return [(sid, src_fid, fid) for fid in sample_frames] else: return [(sid, src_fid, fid) for fid in [src_fid-3,src_fid,src_fid+3, src_fid-4,src_fid-1,src_fid+2, src_fid-2,src_fid+1,src_fid+4 ]] def load_samples_sequence(self,select_frames): """ load samples sequence Returns: pytorch tensors """ OH, OW=self.feature_size images, boxes = [], [] activities, actions = [], [] for i, (sid, src_fid, fid) in enumerate(select_frames): img = Image.open(self.images_path + '/%d/%d/%d.jpg' % (sid, src_fid, fid)) img=transforms.functional.resize(img,self.image_size) img=np.array(img) # H,W,3 -> 3,H,W img=img.transpose(2,0,1) images.append(img) temp_boxes=np.ones_like(self.tracks[(sid, src_fid)][fid]) for i,track in enumerate(self.tracks[(sid, src_fid)][fid]): y1,x1,y2,x2 = track w1,h1,w2,h2 = x1*OW, y1*OH, x2*OW, y2*OH temp_boxes[i]=np.array([w1,h1,w2,h2]) boxes.append(temp_boxes) actions.append(self.anns[sid][src_fid]['actions']) if len(boxes[-1]) != self.num_boxes: boxes[-1] = np.vstack([boxes[-1], boxes[-1][:self.num_boxes-len(boxes[-1])]]) actions[-1] = actions[-1] + actions[-1][:self.num_boxes-len(actions[-1])] 
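# frames with fewer than num_boxes tracks are padded by repeating their leading boxes and action labels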
activities.append(self.anns[sid][src_fid]['group_activity']) images = np.stack(images) activities = np.array(activities, dtype=np.int32) bboxes = np.vstack(boxes).reshape([-1, self.num_boxes, 4]) actions = np.hstack(actions).reshape([-1, self.num_boxes]) #convert to pytorch tensor images=torch.from_numpy(images).float() bboxes=torch.from_numpy(bboxes).float() actions=torch.from_numpy(actions).long() activities=torch.from_numpy(activities).long() return images, bboxes, actions, activities
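For reference, the heart of `GCN_Module` — an appearance relation graph built from embedded dot products, masked by inter-actor distance and row-normalized with softmax, then applied as one graph convolution — can be sketched in isolation. This is a simplified single-graph version with hypothetical names (`one_relation_graph` is not in the repo); the real module builds `num_graph` such graphs, applies LayerNorm, and sums their outputs, and it computes distances with `calc_pairwise_distance_3d` rather than `torch.cdist`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def one_relation_graph(x, centers, theta, phi, w_gcn, pos_threshold_px):
    """x: (B,N,NFG) box features; centers: (B,N,2) box centers in feature-map pixels."""
    # appearance relation: scaled dot product between two linear embeddings
    e = torch.matmul(theta(x), phi(x).transpose(1, 2)) / theta.out_features ** 0.5  # B,N,N
    # position mask: suppress actor pairs farther apart than the threshold
    dist = torch.cdist(centers, centers)                       # B,N,N pairwise distances
    e = e.masked_fill(dist > pos_threshold_px, float('-inf'))
    G = torch.softmax(e, dim=2)                                # row-normalized relation graph
    return F.relu(w_gcn(torch.matmul(G, x)))                   # one graph convolution

B, N, NFG, NFR = 2, 12, 1024, 256
theta, phi = nn.Linear(NFG, NFR), nn.Linear(NFG, NFR)
w_gcn = nn.Linear(NFG, NFG, bias=False)
out = one_relation_graph(torch.randn(B, N, NFG), torch.rand(B, N, 2) * 157,
                         theta, phi, w_gcn, pos_threshold_px=0.2 * 157)  # cfg.pos_threshold * OW
print(out.shape)  # torch.Size([2, 12, 1024])
```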