Full Code of cvlab-yonsei/MNAD for AI

master 4e108a898605 cached
14 files
58.8 KB
15.0k tokens
86 symbols
1 requests
Download .txt
Repository: cvlab-yonsei/MNAD
Branch: master
Commit: 4e108a898605
Files: 14
Total size: 58.8 KB

Directory structure:
gitextract_9fesjy9r/

├── Evaluate.py
├── MNAD_files/
│   └── style.css
├── README.md
├── Train.py
├── data/
│   ├── data_seqkey_all.py
│   ├── frame_labels_avenue.npy
│   ├── frame_labels_ped2.npy
│   └── frame_labels_shanghai.npy
├── model/
│   ├── Memory.py
│   ├── Reconstruction.py
│   ├── final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1.py
│   ├── memory_final_spatial_sumonly_weight_ranking_top1.py
│   └── utils.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: Evaluate.py
================================================
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.nn.init as init
import torch.utils.data as data
import torch.utils.data.dataset as dataset
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as v_utils
import matplotlib.pyplot as plt
import cv2
import math
from collections import OrderedDict
import copy
import time
from model.utils import DataLoader
from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import *
from model.Reconstruction import *
from sklearn.metrics import roc_auc_score
from utils import *
import random
import glob

import argparse


# Command-line options for evaluation (names mirror Train.py where they overlap).
parser = argparse.ArgumentParser(description="MNAD")
parser.add_argument('--gpus', nargs='+', type=str, help='gpus')
parser.add_argument('--batch_size', type=int, default=4, help='batch size for training')
parser.add_argument('--test_batch_size', type=int, default=1, help='batch size for test')
parser.add_argument('--h', type=int, default=256, help='height of input images')
parser.add_argument('--w', type=int, default=256, help='width of input images')
parser.add_argument('--c', type=int, default=3, help='channel of input images')
parser.add_argument('--method', type=str, default='pred', help='The target task for anoamly detection')
parser.add_argument('--t_length', type=int, default=5, help='length of the frame sequences')
parser.add_argument('--fdim', type=int, default=512, help='channel dimension of the features')
parser.add_argument('--mdim', type=int, default=512, help='channel dimension of the memory items')
parser.add_argument('--msize', type=int, default=10, help='number of the memory items')
parser.add_argument('--alpha', type=float, default=0.6, help='weight for the anomality score')
parser.add_argument('--th', type=float, default=0.01, help='threshold for test updating')
parser.add_argument('--num_workers', type=int, default=2, help='number of workers for the train loader')
parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
parser.add_argument('--dataset_type', type=str, default='ped2', help='type of dataset: ped2, avenue, shanghai')
parser.add_argument('--dataset_path', type=str, default='./dataset', help='directory of data')
# Both files are loaded unconditionally with torch.load further down, so they
# are required; without them the script would only fail later inside torch.load.
parser.add_argument('--model_dir', type=str, required=True, help='path of the trained model file (e.g. model.pth)')
parser.add_argument('--m_items_dir', type=str, required=True, help='path of the saved memory items file (e.g. keys.pt)')

args = parser.parse_args()

# Expose only the requested GPU ids to CUDA (ids follow PCI bus order);
# device "0" is used when --gpus is not given.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
gpus = "0" if args.gpus is None else "".join(gpu_id + "," for gpu_id in args.gpus)
# `gpus` carries a trailing comma when built from --gpus; strip it for CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = gpus if args.gpus is None else gpus[:-1]

# make sure to use cudnn for computational performance
torch.backends.cudnn.enabled = True

# Test frames: one sub-folder per video under
# <dataset_path>/<dataset_type>/testing/frames.
test_folder = "/".join((args.dataset_path, args.dataset_type, "testing", "frames"))

# Loading dataset (windows use time_step = t_length - 1)
test_dataset = DataLoader(
    test_folder,
    transforms.Compose([transforms.ToTensor()]),
    resize_height=args.h,
    resize_width=args.w,
    time_step=args.t_length - 1,
)
test_size = len(test_dataset)

# shuffle=False: the scoring loop assigns results to videos by position.
test_batch = data.DataLoader(
    test_dataset,
    batch_size=args.test_batch_size,
    shuffle=False,
    num_workers=args.num_workers_test,
    drop_last=False,
)

# Element-wise squared error; the scoring loop averages it per frame.
loss_func_mse = nn.MSELoss(reduction='none')

# Loading the trained model and its memory items.
# NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoints.
model = torch.load(args.model_dir)
model.cuda()
m_items = torch.load(args.m_items_dir)
labels = np.load('./data/frame_labels_{}.npy'.format(args.dataset_type))

if args.test_batch_size != 1:
    # Video boundaries below are found by counting loader steps as frames, and
    # only outputs[0] is scored, so any other batch size would silently
    # misalign scores and labels.
    raise ValueError('Evaluate.py only supports --test_batch_size 1')

# Index the test videos: every entry of test_folder is one video whose frames
# are its *.jpg files in sorted order.
videos = OrderedDict()
videos_list = sorted(glob.glob(os.path.join(test_folder, '*')))
for video in videos_list:
    video_name = video.split('/')[-1]
    frames = sorted(glob.glob(os.path.join(video, '*.jpg')))
    videos[video_name] = {'path': video, 'frame': frames, 'length': len(frames)}

# Number of leading frames per video that never receive a score. In 'pred'
# mode the first t_length - 1 frames of a window are the model input and the
# remaining frame is the target; the README's 'recon' setting uses
# --t_length 1, which gives 0. (This used to be hard-coded to 4, i.e. it was
# only correct for t_length == 5.)
offset = args.t_length - 1
# 'pred' input channels: c channels for each of the `offset` input frames
# (previously hard-coded as 3*4).
n_input_channels = args.c * offset

labels_list = []
label_length = 0
psnr_list = {}
feature_distance_list = {}

print('Evaluation of', args.dataset_type)

# Setting for video anomaly detection: keep the ground-truth label of every
# scored frame, video by video, in the order the loader visits them.
for video in videos_list:  # already sorted
    video_name = video.split('/')[-1]
    video_length = videos[video_name]['length']
    labels_list = np.append(labels_list, labels[0][offset + label_length:video_length + label_length])
    label_length += video_length
    psnr_list[video_name] = []
    feature_distance_list[video_name] = []

# From here on label_length is the cumulative frame count up to and including
# the current video (index video_num).
video_num = 0
label_length = videos[videos_list[video_num].split('/')[-1]]['length']
m_items_test = m_items.clone()

model.eval()

# Inference only: no autograd graph is needed for scoring or memory updates.
with torch.no_grad():
    for k, imgs in enumerate(test_batch):
        # This bookkeeping assumes the loader yields `length - offset` items per
        # video (model.utils.DataLoader with time_step=offset; not visible
        # here), so reaching the cumulative item count means a new video began.
        if k == label_length - offset * (video_num + 1):
            video_num += 1
            label_length += videos[videos_list[video_num].split('/')[-1]]['length']

        imgs = imgs.cuda()

        if args.method == 'pred':
            outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, _, _, _, compactness_loss = model.forward(imgs[:, 0:n_input_channels], m_items_test, False)
            target = imgs[:, n_input_channels:]
        else:
            outputs, feas, updated_feas, m_items_test, softmax_score_query, softmax_score_memory, compactness_loss = model.forward(imgs, m_items_test, False)
            target = imgs

        # (x + 1) / 2 rescaling presumably maps frames from [-1, 1] to [0, 1]
        # - the input range is set by model.utils.DataLoader; confirm there.
        mse_imgs = torch.mean(loss_func_mse((outputs[0] + 1) / 2, (target[0] + 1) / 2)).item()
        mse_feas = compactness_loss.item()

        # Calculating the threshold for updating at the test time
        point_sc = point_score(outputs, target)

        # When the score is below --th, the memory items are updated with this
        # frame's L2-normalised query features.
        if point_sc < args.th:
            query = F.normalize(feas, dim=1)
            query = query.permute(0, 2, 3, 1)  # b X h X w X d
            m_items_test = model.memory.update(query, m_items_test, False)

        current_video = videos_list[video_num].split('/')[-1]
        psnr_list[current_video].append(psnr(mse_imgs))
        feature_distance_list[current_video].append(mse_feas)


# Measuring the abnormality score and the AUC: per video, the PSNR-based and
# feature-distance-based scores are combined by score_sum with weight
# --alpha, then all videos are concatenated in sorted order.
per_frame_scores = []
for video_path in sorted(videos_list):
    name = video_path.split('/')[-1]
    psnr_scores = anomaly_score_list(psnr_list[name])
    feature_scores = anomaly_score_list_inv(feature_distance_list[name])
    per_frame_scores.extend(score_sum(psnr_scores, feature_scores, args.alpha))

anomaly_score_total_list = np.asarray(per_frame_scores)

# Labels are passed to AUC inverted (1 - label), as a single-row array.
accuracy = AUC(anomaly_score_total_list, np.expand_dims(1 - labels_list, 0))

print('The result of ', args.dataset_type)
print('AUC: ', accuracy * 100, '%')


================================================
FILE: MNAD_files/style.css
================================================
/* Space out content a bit */

@import url('https://fonts.googleapis.com/css?family=Baloo|Bungee+Inline|Lato|Righteous|Shojumaru');

body {
  padding-top: 20px;
  padding-bottom: 20px;
  /* Lato is a sans-serif face: if it fails to load, fall back to a
     sans-serif font instead of a cursive one. */
  font-family: 'Lato', sans-serif;
  font-size: 14px;
}

/* Mobile-first side spacing for everything except the jumbotron */
.header,
.row,
.footer {
  padding-right: 15px;
  padding-left: 15px;
}

/* Page header */
.header {
  border-bottom: 1px solid #e5e5e5;
}
/* Masthead headings share the navigation's height and spacing */
.header h1,
.header h3 {
  margin-top: 0;
  margin-bottom: 0;
  line-height: 40px;
  padding-bottom: 19px;
}
.header h1 {
  font-size: 30px;
  font-weight: bold;
}
.header h3 {
  font-size: 20px;
}
.header h4 {
  font-family: 'Baloo', cursive;
}

/* Page footer */
.footer {
  padding-top: 19px;
  color: #777;
  border-top: 1px solid #e5e5e5;
}

/* Cap the container width on wide screens */
@media (min-width: 938px) {
  .container {
    max-width: 900px;
  }
}
.container-narrow > hr {
  margin: 20px 0;
}

/* Main marketing message and sign up button */
.container .jumbotron {
  text-align: center;
  border-bottom: 1px solid #e5e5e5;
  /* One shorthand for all four sides (a separate padding-left placed before
     it would be overridden and never apply). */
  padding: 30px;
}
.jumbotron .btn {
  font-size: 21px;
  padding: 14px 24px;
}

.row p + h3 {
  margin-top: 28px;
}

div.row h3 {
  padding-bottom: 5px;
  border-bottom: 1px solid #ccc;
}

/* Responsive: Portrait tablets and up */
@media screen and (min-width: 938px) {
  /* Remove the padding we set earlier */
  .header,
  .marketing,
  .footer {
    padding-left: 0;
    padding-right: 0;
  }
  /* Space out the masthead */
  .header {
    margin-bottom: 30px;
  }
  /* Remove the bottom border on the jumbotron for visual effect. The border
     is set by `.container .jumbotron`, so this override needs at least that
     specificity; a bare `.jumbotron` selector never won. */
  .container .jumbotron,
  .jumbotron {
    border-bottom: 0;
  }
}

.readme h1 {
  display: none;
}

/* These rules used to contain `float: middle`. That is not a valid `float`
   value (left | right | none | inline-start | inline-end), so browsers
   dropped the declaration and it never had any effect. If centering was
   intended, use e.g. `margin: 0 auto` on the element or `text-align`/flex on
   its parent. */
.left_column {
}

.right_column {
}

================================================
FILE: README.md
================================================
# PyTorch implementation of "Learning Memory-guided Normality for Anomaly Detection"

<p align="center"><img src="./MNAD_files/overview.png" alt="no_image" width="40%" height="40%" /><img src="./MNAD_files/teaser.png" alt="no_image" width="60%" height="60%" /></p>
This is the implementation of the paper "Learning Memory-guided Normality for Anomaly Detection (CVPR 2020)".

For more information, check out the project site [[website](https://cvlab.yonsei.ac.kr/projects/MNAD/)] and the paper [[PDF](http://openaccess.thecvf.com/content_CVPR_2020/papers/Park_Learning_Memory-Guided_Normality_for_Anomaly_Detection_CVPR_2020_paper.pdf)].

## Dependencies
* Python 3.6
* PyTorch 1.1.0
* Numpy
* Sklearn

## Datasets
* UCSD Ped2 [[dataset](https://github.com/StevenLiuWen/ano_pred_cvpr2018)]
* CUHK Avenue [[dataset](https://github.com/StevenLiuWen/ano_pred_cvpr2018)]
* ShanghaiTech [[dataset](https://github.com/StevenLiuWen/ano_pred_cvpr2018)]

These datasets are from the official GitHub repository of "Future Frame Prediction for Anomaly Detection - A New Baseline (CVPR 2018)".

Download the datasets into the ``dataset`` folder, e.g., ``./dataset/ped2/``.

## Update
* 02/04/21: We uploaded the code for the reconstruction method and pretrained weights for Ped2 reconstruction, Avenue prediction and Avenue reconstruction.


## Training
* ~~The training and testing codes are based on prediction method~~
* Now you can run the code with either the prediction or the reconstruction method.
* The code uses the prediction method by default; you can run it as
```bash
git clone https://github.com/cvlab-yonsei/projects
cd projects/MNAD/code
python Train.py # for training
```
* You can freely define parameters with your own settings like
```bash
python Train.py --gpus 1 --dataset_path 'your_dataset_directory' --dataset_type avenue --exp_dir 'your_log_directory'
```
* For the reconstruction task, you need to set the parameters accordingly, *e.g.*, the target task, the weights of the losses and the length of the frame sequence.
```bash
python Train.py --method recon --loss_compact 0.01 --loss_separate 0.01 --t_length 1 # for training
```

## Evaluation
* Test your own model
* Check your dataset_type (ped2, avenue or shanghai)
```bash
python Evaluate.py --dataset_type ped2 --model_dir your_model.pth --m_items_dir your_m_items.pt
```
* For the reconstruction task, you need to set the parameters as
```bash
python Evaluate.py --method recon --t_length 1 --alpha 0.7 --th 0.015 --dataset_type ped2 --model_dir your_model.pth --m_items_dir your_m_items.pt
```
* Test the model with our pre-trained model and memory items
```bash
python Evaluate.py --dataset_type ped2 --model_dir pretrained_model.pth --m_items_dir m_items.pt
```

## Pre-trained model and memory items

Will be released soon.
<!--
* Download our pre-trained model and memory items 
<br>[[Ped2 Prediction](https://drive.google.com/file/d/1NdsGKUPvdNNwsnWcMYeO44gX2h-oJlEn/view?usp=sharing)]
<br>[[Ped2 Reconstruction](https://drive.google.com/file/d/1HgntMYJd_Qn5L1wLnsz3xnbjGwbmd5uJ/view?usp=sharing)]
<br>[[Avenue Prediction](https://drive.google.com/file/d/1q7auxT21We9bg5ySsLP9HoqsxPATsd8K/view?usp=sharing)]
<br>[[Avenue Reconstruction](https://drive.google.com/file/d/1mFADg-97ZWXIvZ-tAcoN7hoCFHXMN7Gc/view?usp=sharing)]

* Note that, you need to set lambda and threshold to 0.7 and 0.015, respectively, for the reconstruction task. See more details in the paper.
-->

## Bibtex
```
@inproceedings{park2020learning,
  title={Learning Memory-guided Normality for Anomaly Detection},
  author={Park, Hyunjong and Noh, Jongyoun and Ham, Bumsub},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={14372--14381},
  year={2020}
}
```


================================================
FILE: Train.py
================================================
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.nn.init as init
import torch.utils.data as data
import torch.utils.data.dataset as dataset
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as v_utils
import matplotlib.pyplot as plt
import cv2
import math
from collections import OrderedDict
import copy
import time
from model.utils import DataLoader
from sklearn.metrics import roc_auc_score
from utils import *
import random

import argparse


parser = argparse.ArgumentParser(description="MNAD")
parser.add_argument('--gpus', nargs='+', type=str, help='gpus')
# Remaining options as (flag, type, default, help), listed in --help order.
for _flag, _kind, _default, _help in (
    ('--batch_size', int, 4, 'batch size for training'),
    ('--test_batch_size', int, 1, 'batch size for test'),
    ('--epochs', int, 60, 'number of epochs for training'),
    ('--loss_compact', float, 0.1, 'weight of the feature compactness loss'),
    ('--loss_separate', float, 0.1, 'weight of the feature separateness loss'),
    ('--h', int, 256, 'height of input images'),
    ('--w', int, 256, 'width of input images'),
    ('--c', int, 3, 'channel of input images'),
    ('--lr', float, 2e-4, 'initial learning rate'),
    ('--method', str, 'pred', 'The target task for anoamly detection'),
    ('--t_length', int, 5, 'length of the frame sequences'),
    ('--fdim', int, 512, 'channel dimension of the features'),
    ('--mdim', int, 512, 'channel dimension of the memory items'),
    ('--msize', int, 10, 'number of the memory items'),
    ('--num_workers', int, 2, 'number of workers for the train loader'),
    ('--num_workers_test', int, 1, 'number of workers for the test loader'),
    ('--dataset_type', str, 'ped2', 'type of dataset: ped2, avenue, shanghai'),
    ('--dataset_path', str, './dataset', 'directory of data'),
    ('--exp_dir', str, 'log', 'directory of log'),
):
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)
del _flag, _kind, _default, _help

args = parser.parse_args()

# Expose only the requested GPU ids to CUDA (ids follow PCI bus order);
# device "0" is used when --gpus is not given.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
gpus = "0" if args.gpus is None else "".join(gpu_id + "," for gpu_id in args.gpus)
# `gpus` carries a trailing comma when built from --gpus; strip it for CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = gpus if args.gpus is None else gpus[:-1]

# make sure to use cudnn for computational performance
torch.backends.cudnn.enabled = True

# Frame folders: <dataset_path>/<dataset_type>/{training,testing}/frames
train_folder, test_folder = (
    "/".join((args.dataset_path, args.dataset_type, split, "frames"))
    for split in ("training", "testing")
)

# Loading dataset: both splits share frame size and window length
# (time_step = t_length - 1); the training split is built first.
train_dataset, test_dataset = (
    DataLoader(folder, transforms.Compose([transforms.ToTensor()]),
               resize_height=args.h, resize_width=args.w, time_step=args.t_length - 1)
    for folder in (train_folder, test_folder)
)

train_size = len(train_dataset)
test_size = len(test_dataset)

# Training batches are shuffled and the ragged last batch is dropped;
# the test loader keeps frame order.
train_batch = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True)
test_batch = data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False,
                             num_workers=args.num_workers_test, drop_last=False)


# Model setting: each branch star-imports the module for its task, which
# provides the `convAE` class used right after.
if args.method not in ('pred', 'recon'):
    # Explicit check instead of `assert`, which `python -O` strips.
    raise ValueError('Wrong task name')
if args.method == 'pred':
    from model.final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1 import *
    model = convAE(args.c, args.t_length, args.msize, args.fdim, args.mdim)
else:
    from model.Reconstruction import *
    model = convAE(args.c, memory_size = args.msize, feature_dim = args.fdim, key_dim = args.mdim)

# Only the encoder and decoder parameters are handed to Adam; the learning
# rate follows a cosine schedule over the training epochs.
params_encoder = list(model.encoder.parameters())
params_decoder = list(model.decoder.parameters())
params = params_encoder + params_decoder
optimizer = torch.optim.Adam(params, lr = args.lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = args.epochs)
model.cuda()


# Report the training process: everything printed from here until the model
# is saved goes to <log_dir>/log.txt instead of the console.
import contextlib

log_dir = os.path.join('./exp', args.dataset_type, args.method, args.exp_dir)
os.makedirs(log_dir, exist_ok=True)

loss_func_mse = nn.MSELoss(reduction='none')

# In 'pred' mode the first t_length - 1 frames (c channels each) are fed to the
# model and the remaining channels are the frame it must predict. This was
# hard-coded as 12 (= 3 channels x 4 frames), i.e. only valid for t_length 5.
n_input_channels = args.c * (args.t_length - 1)

# The `with` statement restores sys.stdout and closes the log file even if
# training raises.
with open(os.path.join(log_dir, 'log.txt'), 'w') as f, contextlib.redirect_stdout(f):
    # Training

    # Initialize the memory items: msize random rows of size mdim, L2-normalised.
    m_items = F.normalize(torch.rand((args.msize, args.mdim), dtype=torch.float), dim=1).cuda()

    for epoch in range(args.epochs):
        model.train()

        for imgs in train_batch:
            imgs = imgs.cuda()

            if args.method == 'pred':
                inputs, target = imgs[:, :n_input_channels], imgs[:, n_input_channels:]
            else:
                inputs, target = imgs, imgs

            outputs, _, _, m_items, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss = model.forward(inputs, m_items, True)

            optimizer.zero_grad()
            loss_pixel = torch.mean(loss_func_mse(outputs, target))
            loss = loss_pixel + args.loss_compact * compactness_loss + args.loss_separate * separateness_loss
            # NOTE(review): retain_graph=True is kept from the original; it may be
            # required if m_items carries autograd history between steps - confirm.
            loss.backward(retain_graph=True)
            optimizer.step()

        scheduler.step()

        # The losses reported below are those of the epoch's last batch.
        print('----------------------------------------')
        print('Epoch:', epoch+1)
        if args.method == 'pred':
            print('Loss: Prediction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item()))
        else:
            print('Loss: Reconstruction {:.6f}/ Compactness {:.6f}/ Separateness {:.6f}'.format(loss_pixel.item(), compactness_loss.item(), separateness_loss.item()))
        print('Memory_items:')
        print(m_items)
        print('----------------------------------------')

    print('Training is finished')
    # Save the model and the memory items
    torch.save(model, os.path.join(log_dir, 'model.pth'))
    torch.save(m_items, os.path.join(log_dir, 'keys.pt'))





================================================
FILE: data/data_seqkey_all.py
================================================
import numpy as np
import os
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import os.path
import sys


def pil_loader(path):
    """Load the image stored at ``path`` with PIL and return it as RGB."""
    # Go through an explicitly opened file object so the handle is closed
    # promptly (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


def accimage_loader(path):
    """Load ``path`` with accimage, falling back to PIL when accimage raises IOError.

    The ``accimage`` import happens first, so a missing package raises
    ImportError rather than falling back.
    """
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Likely a decoding problem inside accimage: let PIL try instead.
        image = pil_loader(path)
    return image


def default_loader(path):
    """Load ``path`` with the loader matching torchvision's active image backend."""
    from torchvision import get_image_backend
    backend_loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return backend_loader(path)
    
def make_dataset(dir, class_to_idx):
    """Collect every file path found under the class folders of ``dir``.

    Args:
        dir: root directory; ``~`` is expanded.
        class_to_idx: mapping whose keys are sub-folder names of ``dir``.
            Only the keys are used; keys without a matching directory are
            skipped.

    Returns:
        list[str]: paths of all files under each class folder, walked
        recursively. Class folders are visited in sorted key order, and files
        are sorted by name within each directory. No extension filtering is
        applied.
    """
    frames = []
    dir = os.path.expanduser(dir)
    for target in sorted(class_to_idx.keys()):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                frames.append(os.path.join(root, fname))
    return frames


class DatasetFolder(data.Dataset):
    """Dataset of fixed-length frame windows over all files under ``root``.

    All files returned by ``make_dataset`` form one flat, ordered sequence:
    class sub-folders are taken in sorted order and files are sorted within
    each one. Item ``i`` is the ``length`` consecutive frames starting at
    position ``i``. Each frame goes through ``transform``, and the frames are
    concatenated along dim 0.

    NOTE(review): windows are taken over the flat list, so a window starting
    near the end of one sub-folder continues into the next one.
    """

    def __init__(self, root, loader=default_loader, transform=None, target_transform=None, length=5):
        if length < 1:
            raise ValueError('length must be at least 1, got {}'.format(length))
        classes, class_to_idx = self._find_classes(root)
        samples = make_dataset(root, class_to_idx)
        if len(samples) == 0:
            raise RuntimeError("Found 0 files in subfolders of: " + root)

        self.root = root
        self.loader = loader
        self.length = length
        self.classes = classes
        self.class_to_idx = class_to_idx
        # Start frames of every complete window. The explicit stop index keeps
        # all frames when length == 1; the previous `samples[:-(length - 1)]`
        # became `samples[:-0]`, i.e. an empty dataset.
        self.samples = samples[:max(0, len(samples) - (length - 1))]

        self.samples_all = samples
        self.samples_pool = samples[1:]

        self.transform = transform
        # Stored for API compatibility; __getitem__ never applies it.
        self.target_transform = target_transform

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.
        Args:
            dir (string): Root directory path.
        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): index of the window's first frame.
        Returns:
            The ``self.length`` frames starting at ``index``, each transformed
            and concatenated along dim 0. When ``length == 1`` the single
            transformed frame is returned as-is.
        Raises:
            IndexError: if ``index`` is not a valid window start.
        """
        # Indexing self.samples first keeps IndexError for out-of-range starts.
        paths = [self.samples[index]]
        paths += [self.samples_all[index + step] for step in range(1, self.length)]

        frames = []
        for path in paths:
            frame = self.loader(path)
            if self.transform is not None:
                frame = self.transform(frame)
            frames.append(frame)

        if len(frames) == 1:
            return frames[0]
        # One concatenation instead of growing the tensor frame by frame.
        return torch.cat(frames, dim=0)

    def _stride(self):
        """Return a random stride in {1, 2, 3} (currently unused by this class)."""
        return int(np.random.choice(3, 1)[0]) + 1

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


# Image file extensions, each including the leading dot ('webp' was missing it).
# NOTE(review): make_dataset in this file does not consult this list.
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp']





class ImageFolder(DatasetFolder):
    """``DatasetFolder`` taking ``transform`` before ``loader``.

    ``imgs`` is an alias of ``samples`` (the window start frames).
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, length=5):
        # `length` is forwarded explicitly; it used to be dropped, so every
        # ImageFolder silently used the base-class default of 5.
        super(ImageFolder, self).__init__(root, loader,
                                          transform=transform,
                                          target_transform=target_transform,
                                          length=length)
        self.imgs = self.samples
        




================================================
FILE: model/Memory.py
================================================
import torch
import torch.autograd as ag
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import functools
import random
from torch.nn import functional as F

def random_uniform(shape, low, high, cuda):
    """Sample a tensor of ``shape`` uniformly from [low, high).

    The result is moved to the GPU when ``cuda`` is truthy.
    """
    scaled = (high - low) * torch.rand(*shape) + low
    return scaled.cuda() if cuda else scaled
    
def distance(a, b):
    """Euclidean distance between ``a`` and ``b`` summed over all elements, as a 1-element tensor."""
    diff = a - b
    return diff.pow(2).sum().sqrt().unsqueeze(0)

def distance_batch(a, b):
    """Distance from every row of ``a`` to ``b``.

    Args:
        a: 2-D tensor; its first dimension is the batch (bs rows).
        b: tensor that can be subtracted from a single row of ``a``
            (broadcasting applies, as in ``distance``).

    Returns:
        1-D tensor of length bs whose entry i is ``distance(a[i], b)``;
        an empty tensor when ``a`` has no rows.
    """
    bs, _ = a.shape
    if bs == 0:
        return a.new_empty(0)
    # The previous version seeded with a[0] and then appended a[0] .. a[bs-2],
    # so the first row was counted twice and the last row never measured.
    return torch.cat([distance(a[i], b) for i in range(bs)], 0)

def multiply(x):
    """Return the product of the items of ``x`` (e.g. a tensor size); 1 when empty."""
    product = 1
    for factor in x:
        product = product * factor
    return product

def flatten(x):
    """Flatten tensor ``x`` into a vector IN PLACE (via ``resize_``) and return that same tensor."""
    return x.resize_(x.numel())

def index(batch_size, x):
    """Prepend a column holding the row numbers 0 .. batch_size-1 to ``x``.

    ``x`` is concatenated along dim 1, so it presumably has batch_size rows
    and an integer (long) dtype - confirm at call sites.
    """
    rows = torch.arange(0, batch_size).long().view(-1, 1)
    return torch.cat((rows, x), dim=1)

def MemoryLoss(memory):
    """Mean absolute off-identity similarity between memory items.

    Args:
        memory: (m, d) tensor, one memory item per row. Presumably rows are
            L2-normalised, so the rescaled similarity lies in [0, 1]
            (TODO confirm at call sites).

    Returns:
        Scalar tensor: sum(|S - I|) / (m * (m - 1)), where
        S = (memory @ memory.T) / 2 + 1/2. For m == 1 this divides by zero.
    """
    m, _ = memory.size()
    memory_t = torch.t(memory)
    similarity = (torch.matmul(memory, memory_t)) / 2 + 1 / 2  # m X m
    # Build the identity on memory's device; the former hard-coded `.cuda()`
    # made the loss fail for CPU tensors.
    identity_mask = torch.eye(m, device=memory.device)
    sim = torch.abs(similarity - identity_mask)

    return torch.sum(sim) / (m * (m - 1))


class Memory(nn.Module):
    """Memory read/update module operating on an externally held memory bank.

    The memory items (``keys``, shape (m, d)) are not stored on the module.
    ``forward`` receives them and reads from them. In training mode it also
    returns an updated, L2-normalised, detached copy.

    The constructor arguments are only stored as attributes; none of the
    methods below read them.
    """

    def __init__(self, memory_size, feature_dim, key_dim,  temp_update, temp_gather):
        super(Memory, self).__init__()
        # Constants
        self.memory_size = memory_size
        self.feature_dim = feature_dim
        self.key_dim = key_dim
        self.temp_update = temp_update
        self.temp_gather = temp_gather

    def hard_neg_mem(self, mem, i):
        """Return, per row of ``mem``, the most similar row of ``self.keys_var`` other than row ``i``.

        NOTE(review): reads ``self.keys_var``, which this class never sets;
        calling this without assigning that attribute first raises
        AttributeError.
        """
        similarity = torch.matmul(mem, torch.t(self.keys_var))
        similarity[:, i] = -1  # exclude key ``i`` from the argmax
        _, max_idx = torch.topk(similarity, 1, dim=1)
        return self.keys_var[max_idx]

    def random_pick_memory(self, mem, max_indices):
        """For each memory slot, pick one random row of ``max_indices`` that points to it.

        Args:
            mem: (m, d) memory; only its size is used.
            max_indices: per-query nearest-slot indices.

        Returns:
            1-D int64 tensor of length ``m``: a randomly chosen row index per
            slot, or -1 for slots that no row points to.
        """
        m, _ = mem.size()
        output = []
        for i in range(m):
            flattened_indices = (max_indices == i).nonzero()
            a, _ = flattened_indices.size()
            if a != 0:
                number = np.random.choice(a, 1)[0]
                # int(): keep the list homogeneous; mixing 1-element tensors
                # with -1 ints made torch.tensor() fail or change shape.
                output.append(int(flattened_indices[number, 0]))
            else:
                output.append(-1)
        return torch.tensor(output)

    def get_update_query(self, mem, max_indices, update_indices, score, query, train):
        """Aggregate, per memory slot, the queries whose nearest slot it is.

        For slot ``i`` the result is
        ``sum_j (score[j, i] / max_k score[k, i]) * query[j]`` over all rows
        ``j`` with ``max_indices[j, 0] == i``. Slots no row points to get a
        zero vector.

        Args:
            mem: (m, d) memory items; only its size is used.
            max_indices: (n, 1) int64 index of the nearest slot per query.
            update_indices: unused; kept for interface compatibility.
            score: (n, m) scores (softmax over queries at the call site).
            query: (n, d) flattened queries.
            train: unused; the former train and test paths were identical.

        Returns:
            (m, d) tensor on ``query``'s device. The former version always
            allocated with ``.cuda()``.
        """
        m, _ = mem.size()
        # One-hot assignment of each query row to its nearest slot: (n, m).
        assignment = F.one_hot(max_indices.reshape(-1), num_classes=m).to(score.dtype)
        # Normalise each slot's column of scores by that column's maximum.
        column_max = torch.max(score, dim=0, keepdim=True)[0]  # (1, m)
        # where(): unassigned entries are exactly 0 even if a column max is 0.
        weights = torch.where(assignment > 0, score / column_max, torch.zeros_like(score))
        return torch.matmul(weights.t(), query)  # (m, d)

    def get_score(self, mem, query):
        """Similarity scores between every query position and every memory item.

        Args:
            mem: (m, d) memory items.
            query: (b, h, w, d) queries.

        Returns:
            ``(score_query, score_memory)``, both (b*h*w, m). They are the
            raw dot products soft-maxed over queries (dim 0) and over memory
            items (dim 1), respectively.
        """
        bs, h, w, _ = query.size()
        m, _ = mem.size()
        score = torch.matmul(query, torch.t(mem))  # b x h x w x m
        score = score.view(bs * h * w, m)  # (b*h*w) x m
        score_query = F.softmax(score, dim=0)
        score_memory = F.softmax(score, dim=1)
        return score_query, score_memory

    def forward(self, query, keys, train=True):
        """Read from (and in training, update) the memory for a feature map.

        Args:
            query: (b, d, h, w) feature map; it is L2-normalised over dim 1.
            keys: (m, d) memory items.
            train: selects the training or test behaviour.

        Returns:
            If ``train`` is True: ``(updated_query, updated_memory,
            softmax_score_query, softmax_score_memory, gathering_loss,
            spreading_loss)``.
            Otherwise: ``(updated_query, keys, softmax_score_query,
            softmax_score_memory, gathering_loss)``; in this case the memory
            is returned unchanged.
        """
        query = F.normalize(query, dim=1)
        query = query.permute(0, 2, 3, 1)  # b x h x w x d

        if train:
            gathering_loss = self.gather_loss(query, keys, train)
            spreading_loss = self.spread_loss(query, keys, train)
            updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
            updated_memory = self.update(query, keys, train)
            return updated_query, updated_memory, softmax_score_query, softmax_score_memory, gathering_loss, spreading_loss

        # test: no memory update
        gathering_loss = self.gather_loss(query, keys, train)
        updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
        updated_memory = keys
        return updated_query, updated_memory, softmax_score_query, softmax_score_memory, gathering_loss

    def update(self, query, keys, train):
        """Return the detached, L2-normalised memory after one update step.

        Each memory item is updated by adding the score-weighted sum of the
        queries that picked it as their nearest item (see
        ``get_update_query``).
        """
        batch_size, h, w, dims = query.size()  # b x h x w x d
        softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)

        _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
        _, updating_indices = torch.topk(softmax_score_query, 1, dim=0)

        # The train and test paths used to be identical copies; one call suffices.
        query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train)
        updated_memory = F.normalize(query_update + keys, dim=1)
        return updated_memory.detach()

    def pointwise_gather_loss(self, query_reshape, keys, gathering_indices, train):
        """Element-wise (unreduced) squared error between queries and their gathered keys."""
        loss_mse = torch.nn.MSELoss(reduction='none')
        pointwise_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
        return pointwise_loss

    def spread_loss(self, query, keys, train):
        """Triplet loss (margin 1.0) with nearest key as positive and 2nd-nearest as negative.

        Keys are detached, so gradients flow only into the queries. Needs at
        least two memory items (``topk(..., 2)``).
        """
        batch_size, h, w, dims = query.size()  # b X h X w X d
        loss = torch.nn.TripletMarginLoss(margin=1.0)
        softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)

        _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)

        # 1st, 2nd closest memories
        pos = keys[gathering_indices[:, 0]]
        neg = keys[gathering_indices[:, 1]]
        spreading_loss = loss(query_reshape, pos.detach(), neg.detach())
        return spreading_loss

    def gather_loss(self, query, keys, train):
        """Mean squared error between each query and its nearest (detached) key."""
        batch_size, h, w, dims = query.size()  # b X h X w X d
        loss_mse = torch.nn.MSELoss()
        softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)

        _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
        gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
        return gathering_loss

    def read(self, query, updated_memory):
        """Concatenate each query with its softmax-weighted memory readout.

        Returns ``(updated_query, softmax_score_query, softmax_score_memory)``.
        ``updated_query`` is (b, 2d, h, w); the memory scores used for the
        readout are detached.
        """
        batch_size, h, w, dims = query.size()  # b X h X w X d
        softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)

        concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory)  # (b X h X w) X d
        updated_query = torch.cat((query_reshape, concat_memory), dim=1)  # (b X h X w) X 2d
        updated_query = updated_query.view(batch_size, h, w, 2 * dims)
        updated_query = updated_query.permute(0, 3, 1, 2)
        return updated_query, softmax_score_query, softmax_score_memory
    
    

================================================
FILE: model/Reconstruction.py
================================================
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from .Memory import *

class Encoder(torch.nn.Module):
    """Convolutional encoder: three conv+max-pool stages, then a final conv block.

    The input has ``n_channel * (t_length - 1)`` channels. Each 2x2 max pool
    halves the spatial size, so the output has 512 channels at 1/8 of the
    input resolution.
    """

    def __init__(self, t_length=2, n_channel=3):
        super(Encoder, self).__init__()

        def conv_bn_relu(in_ch, out_ch):
            # One 3x3 convolution followed by batch norm and ReLU.
            return [
                torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
                torch.nn.BatchNorm2d(out_ch),
                torch.nn.ReLU(inplace=False),
            ]

        def double_conv(in_ch, out_ch, activate_last=True):
            # Two stacked convolutions; with ``activate_last=False`` the second
            # convolution is left bare (no batch norm / ReLU after it).
            layers = conv_bn_relu(in_ch, out_ch)
            if activate_last:
                layers.extend(conv_bn_relu(out_ch, out_ch))
            else:
                layers.append(torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1))
            return torch.nn.Sequential(*layers)

        def make_pool():
            return torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.moduleConv1 = double_conv(n_channel * (t_length - 1), 64)
        self.modulePool1 = make_pool()

        self.moduleConv2 = double_conv(64, 128)
        self.modulePool2 = make_pool()

        self.moduleConv3 = double_conv(128, 256)
        self.modulePool3 = make_pool()

        self.moduleConv4 = double_conv(256, 512, activate_last=False)
        # Registered (so they appear in state_dict) but not applied in forward.
        self.moduleBatchNorm = torch.nn.BatchNorm2d(512)
        self.moduleReLU = torch.nn.ReLU(inplace=False)

    def forward(self, x):
        """Encode ``x`` and return the final 512-channel feature map."""
        features = x
        stages = (
            (self.moduleConv1, self.modulePool1),
            (self.moduleConv2, self.modulePool2),
            (self.moduleConv3, self.modulePool3),
        )
        for conv, pool in stages:
            features = pool(conv(features))
        return self.moduleConv4(features)

    
    
class Decoder(torch.nn.Module):
    """Convolutional decoder without skip connections.

    It takes a 1024-channel map, applies three 2x transposed-convolution
    upsamplings interleaved with conv blocks, and emits ``n_channel``
    channels through a final ``Tanh``. ``t_length`` is accepted but not used.
    """

    def __init__(self, t_length=2, n_channel=3):
        super(Decoder, self).__init__()

        def conv3x3(in_ch, out_ch):
            return torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)

        def conv_block(in_ch, out_ch):
            # (conv -> BN -> ReLU) twice.
            return torch.nn.Sequential(
                conv3x3(in_ch, out_ch), torch.nn.BatchNorm2d(out_ch), torch.nn.ReLU(inplace=False),
                conv3x3(out_ch, out_ch), torch.nn.BatchNorm2d(out_ch), torch.nn.ReLU(inplace=False),
            )

        def output_head(in_ch, out_ch, hidden):
            # Two conv-BN-ReLU layers at width ``hidden``, then conv to ``out_ch`` and Tanh.
            return torch.nn.Sequential(
                conv3x3(in_ch, hidden), torch.nn.BatchNorm2d(hidden), torch.nn.ReLU(inplace=False),
                conv3x3(hidden, hidden), torch.nn.BatchNorm2d(hidden), torch.nn.ReLU(inplace=False),
                conv3x3(hidden, out_ch), torch.nn.Tanh(),
            )

        def upsample(in_ch, out_ch):
            # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H and W.
            return torch.nn.Sequential(
                torch.nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
                torch.nn.BatchNorm2d(out_ch),
                torch.nn.ReLU(inplace=False),
            )

        self.moduleConv = conv_block(1024, 512)
        self.moduleUpsample4 = upsample(512, 512)

        self.moduleDeconv3 = conv_block(512, 256)
        self.moduleUpsample3 = upsample(256, 256)

        self.moduleDeconv2 = conv_block(256, 128)
        self.moduleUpsample2 = upsample(128, 128)

        self.moduleDeconv1 = output_head(128, n_channel, 64)

    def forward(self, x):
        """Decode the 1024-channel map ``x`` into an ``n_channel`` image."""
        out = self.moduleConv(x)
        for layer in (self.moduleUpsample4, self.moduleDeconv3,
                      self.moduleUpsample3, self.moduleDeconv2,
                      self.moduleUpsample2, self.moduleDeconv1):
            out = layer(out)
        return out
    


class convAE(torch.nn.Module):
    """Encoder -> memory -> decoder reconstruction model.

    ``forward`` returns the decoder output and the encoder features, followed
    by everything the memory module returns. That is two losses in training
    mode and one in test mode.
    """

    def __init__(self, n_channel=3, t_length=2, memory_size=10, feature_dim=512, key_dim=512, temp_update=0.1, temp_gather=0.1):
        super(convAE, self).__init__()
        self.encoder = Encoder(t_length, n_channel)
        self.decoder = Decoder(t_length, n_channel)
        self.memory = Memory(memory_size, feature_dim, key_dim, temp_update, temp_gather)

    def forward(self, x, keys, train=True):
        """Run encoder, memory and decoder on ``x`` using memory items ``keys``."""
        fea = self.encoder(x)
        if not train:
            # Test mode: the memory yields a single (gathering) loss.
            (updated_fea, keys, softmax_score_query, softmax_score_memory,
             gathering_loss) = self.memory(fea, keys, train)
            output = self.decoder(updated_fea)
            return (output, fea, updated_fea, keys, softmax_score_query,
                    softmax_score_memory, gathering_loss)

        (updated_fea, keys, softmax_score_query, softmax_score_memory,
         gathering_loss, spreading_loss) = self.memory(fea, keys, train)
        output = self.decoder(updated_fea)
        return (output, fea, updated_fea, keys, softmax_score_query,
                softmax_score_memory, gathering_loss, spreading_loss)
        
                                          



    

================================================
FILE: model/final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1.py
================================================
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from .memory_final_spatial_sumonly_weight_ranking_top1 import *

class Encoder(torch.nn.Module):
    """Encoder for frame prediction that also exposes skip features.

    The input has ``n_channel * (t_length - 1)`` channels. ``forward`` returns
    the final 512-channel map first, then the pre-pooling outputs of the first
    three stages (64, 128 and 256 channels).
    """

    def __init__(self, t_length=5, n_channel=3):
        super(Encoder, self).__init__()

        def conv_bn_relu(in_ch, out_ch):
            # One 3x3 convolution followed by batch norm and ReLU.
            return [
                torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
                torch.nn.BatchNorm2d(out_ch),
                torch.nn.ReLU(inplace=False),
            ]

        def double_conv(in_ch, out_ch, activate_last=True):
            # Two stacked convolutions; with ``activate_last=False`` the second
            # convolution is left bare (no batch norm / ReLU after it).
            layers = conv_bn_relu(in_ch, out_ch)
            if activate_last:
                layers.extend(conv_bn_relu(out_ch, out_ch))
            else:
                layers.append(torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1))
            return torch.nn.Sequential(*layers)

        def make_pool():
            return torch.nn.MaxPool2d(kernel_size=2, stride=2)

        self.moduleConv1 = double_conv(n_channel * (t_length - 1), 64)
        self.modulePool1 = make_pool()

        self.moduleConv2 = double_conv(64, 128)
        self.modulePool2 = make_pool()

        self.moduleConv3 = double_conv(128, 256)
        self.modulePool3 = make_pool()

        self.moduleConv4 = double_conv(256, 512, activate_last=False)
        # Registered (so they appear in state_dict) but not applied in forward.
        self.moduleBatchNorm = torch.nn.BatchNorm2d(512)
        self.moduleReLU = torch.nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``(features, skip1, skip2, skip3)``; the skips are the pre-pool stage outputs."""
        skips = []
        features = x
        stages = (
            (self.moduleConv1, self.modulePool1),
            (self.moduleConv2, self.modulePool2),
            (self.moduleConv3, self.modulePool3),
        )
        for conv, pool in stages:
            features = conv(features)
            skips.append(features)
            features = pool(features)
        return (self.moduleConv4(features),) + tuple(skips)

    
    
class Decoder(torch.nn.Module):
    """Decoder with skip connections for frame prediction.

    Each 2x upsampling output is concatenated (skip first) with the matching
    encoder skip tensor before the next conv block. The final block emits
    ``n_channel`` channels through ``Tanh``. ``t_length`` is accepted but not
    used.
    """

    def __init__(self, t_length=5, n_channel=3):
        super(Decoder, self).__init__()

        def conv3x3(in_ch, out_ch):
            return torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)

        def conv_block(in_ch, out_ch):
            # (conv -> BN -> ReLU) twice.
            return torch.nn.Sequential(
                conv3x3(in_ch, out_ch), torch.nn.BatchNorm2d(out_ch), torch.nn.ReLU(inplace=False),
                conv3x3(out_ch, out_ch), torch.nn.BatchNorm2d(out_ch), torch.nn.ReLU(inplace=False),
            )

        def output_head(in_ch, out_ch, hidden):
            # Two conv-BN-ReLU layers at width ``hidden``, then conv to ``out_ch`` and Tanh.
            return torch.nn.Sequential(
                conv3x3(in_ch, hidden), torch.nn.BatchNorm2d(hidden), torch.nn.ReLU(inplace=False),
                conv3x3(hidden, hidden), torch.nn.BatchNorm2d(hidden), torch.nn.ReLU(inplace=False),
                conv3x3(hidden, out_ch), torch.nn.Tanh(),
            )

        def upsample(in_ch, out_ch):
            # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H and W.
            return torch.nn.Sequential(
                torch.nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
                torch.nn.BatchNorm2d(out_ch),
                torch.nn.ReLU(inplace=False),
            )

        self.moduleConv = conv_block(1024, 512)
        self.moduleUpsample4 = upsample(512, 256)

        self.moduleDeconv3 = conv_block(512, 256)
        self.moduleUpsample3 = upsample(256, 128)

        self.moduleDeconv2 = conv_block(256, 128)
        self.moduleUpsample2 = upsample(128, 64)

        self.moduleDeconv1 = output_head(128, n_channel, 64)

    def forward(self, x, skip1, skip2, skip3):
        """Decode ``x`` using skip tensors from the deepest (``skip3``) to the shallowest (``skip1``)."""
        out = self.moduleConv(x)
        steps = (
            (self.moduleUpsample4, skip3, self.moduleDeconv3),
            (self.moduleUpsample3, skip2, self.moduleDeconv2),
            (self.moduleUpsample2, skip1, self.moduleDeconv1),
        )
        for up, skip, block in steps:
            out = block(torch.cat((skip, up(out)), dim=1))
        return out
    


class convAE(torch.nn.Module):
    """Encoder -> memory -> decoder model for future-frame prediction.

    The decoder receives the memory-augmented features together with the
    encoder's three skip tensors. ``forward`` returns the decoder output and
    the features, followed by everything the memory module returns.
    """

    def __init__(self, n_channel=3, t_length=5, memory_size=10, feature_dim=512, key_dim=512, temp_update=0.1, temp_gather=0.1):
        super(convAE, self).__init__()
        self.encoder = Encoder(t_length, n_channel)
        self.decoder = Decoder(t_length, n_channel)
        self.memory = Memory(memory_size, feature_dim, key_dim, temp_update, temp_gather)

    def forward(self, x, keys, train=True):
        """Run encoder, memory and decoder on ``x`` using memory items ``keys``."""
        fea, skip1, skip2, skip3 = self.encoder(x)
        if not train:
            # Test mode: the memory also returns queries and their top-1 keys/indices.
            (updated_fea, keys, softmax_score_query, softmax_score_memory,
             query, top1_keys, keys_ind, compactness_loss) = self.memory(fea, keys, train)
            output = self.decoder(updated_fea, skip1, skip2, skip3)
            return (output, fea, updated_fea, keys, softmax_score_query, softmax_score_memory,
                    query, top1_keys, keys_ind, compactness_loss)

        (updated_fea, keys, softmax_score_query, softmax_score_memory,
         separateness_loss, compactness_loss) = self.memory(fea, keys, train)
        output = self.decoder(updated_fea, skip1, skip2, skip3)
        return (output, fea, updated_fea, keys, softmax_score_query, softmax_score_memory,
                separateness_loss, compactness_loss)
        
                                          



    
    


================================================
FILE: model/memory_final_spatial_sumonly_weight_ranking_top1.py
================================================
import torch
import torch.autograd as ag
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import functools
import random
from torch.nn import functional as F

def random_uniform(shape, low, high, cuda):
    """Sample a tensor of the given ``shape`` uniformly from ``[low, high)``.

    Sampling happens on the CPU; the result is moved with ``.cuda()`` when
    ``cuda`` is truthy.
    """
    span = high - low
    sample = span * torch.rand(*shape) + low
    return sample.cuda() if cuda else sample
    
def distance(a, b):
    """Euclidean distance between ``a`` and ``b``, returned as a 1-element tensor."""
    squared = (a - b) ** 2
    return torch.sqrt(squared.sum()).unsqueeze(0)

def distance_batch(a, b):
    """Euclidean distance from every row of ``a`` to ``b``.

    Args:
        a: 2-D tensor of shape (bs, d).
        b: tensor that broadcasts against a single row of ``a``.

    Returns:
        1-D tensor of length ``bs`` whose i-th entry is
        ``sqrt(sum((a[i] - b) ** 2))``.
    """
    bs, _ = a.shape  # enforces 2-D input, as before
    # The previous loop ran range(bs - 1) starting at 0, so row 0 was
    # measured twice and the last row was never measured.
    return torch.stack([torch.sqrt(((a[i] - b) ** 2).sum()) for i in range(bs)])

def multiply(x):
    """Return the product of all items of ``x`` (1 for an empty iterable)."""
    product = 1
    for item in x:
        product = product * item
    return product

def flatten(x):
    """Flatten a tensor into a 1-D vector IN PLACE via ``resize_``.

    The input tensor itself is resized, and that same tensor is returned.
    """
    return x.resize_(x.numel())

def index(batch_size, x):
    """Prepend a column of row numbers ``0 .. batch_size-1`` to 2-D tensor ``x``."""
    rows = torch.arange(0, batch_size).long().unsqueeze(-1)
    return torch.cat((rows, x), dim=1)

def MemoryLoss(memory):
    """Mean deviation of pairwise memory similarities from the identity.

    Dot products between memory items are mapped with ``s / 2 + 1 / 2``.
    The absolute difference from the identity matrix is then summed and
    divided by ``m * (m - 1)``.

    Args:
        memory: (m, d) tensor of memory items. Presumably L2-normalised so
            the mapped similarities lie in [0, 1] (TODO confirm at call sites).

    Returns:
        0-dim tensor on ``memory``'s device.
    """
    m, _ = memory.size()
    similarity = torch.matmul(memory, torch.t(memory)) / 2 + 1 / 2  # m x m
    # Build the identity on the memory's own device/dtype; the former
    # ``.cuda()`` failed on CPU-only machines and ignored multi-GPU placement.
    identity_mask = torch.eye(m, device=memory.device, dtype=similarity.dtype)
    sim = torch.abs(similarity - identity_mask)
    return torch.sum(sim) / (m * (m - 1))


class Memory(nn.Module):
    """Memory read/update module operating on an externally held memory bank.

    The memory items (``keys``, shape (m, d)) are not stored on the module.
    ``forward`` receives them and reads from them. In training mode it also
    returns an updated, L2-normalised, detached copy.

    The constructor arguments are only stored as attributes; none of the
    methods below read them.
    """

    def __init__(self, memory_size, feature_dim, key_dim,  temp_update, temp_gather):
        super(Memory, self).__init__()
        # Constants
        self.memory_size = memory_size
        self.feature_dim = feature_dim
        self.key_dim = key_dim
        self.temp_update = temp_update
        self.temp_gather = temp_gather

    def hard_neg_mem(self, mem, i):
        """Return, per row of ``mem``, the most similar row of ``self.keys_var`` other than row ``i``.

        NOTE(review): reads ``self.keys_var``, which this class never sets;
        calling this without assigning that attribute first raises
        AttributeError.
        """
        similarity = torch.matmul(mem, torch.t(self.keys_var))
        similarity[:, i] = -1  # exclude key ``i`` from the argmax
        _, max_idx = torch.topk(similarity, 1, dim=1)
        return self.keys_var[max_idx]

    def random_pick_memory(self, mem, max_indices):
        """For each memory slot, pick one random row of ``max_indices`` that points to it.

        Args:
            mem: (m, d) memory; only its size is used.
            max_indices: per-query nearest-slot indices.

        Returns:
            1-D int64 tensor of length ``m``: a randomly chosen row index per
            slot, or -1 for slots that no row points to.
        """
        m, _ = mem.size()
        output = []
        for i in range(m):
            flattened_indices = (max_indices == i).nonzero()
            a, _ = flattened_indices.size()
            if a != 0:
                number = np.random.choice(a, 1)[0]
                # int(): keep the list homogeneous; mixing 1-element tensors
                # with -1 ints made torch.tensor() fail or change shape.
                output.append(int(flattened_indices[number, 0]))
            else:
                output.append(-1)
        return torch.tensor(output)

    def get_update_query(self, mem, max_indices, update_indices, score, query, train):
        """Aggregate, per memory slot, the queries whose nearest slot it is.

        For slot ``i`` the result is
        ``sum_j (score[j, i] / max_k score[k, i]) * query[j]`` over all rows
        ``j`` with ``max_indices[j, 0] == i``. Slots no row points to get a
        zero vector.

        Args:
            mem: (m, d) memory items; only its size is used.
            max_indices: (n, 1) int64 index of the nearest slot per query.
            update_indices: unused; kept for interface compatibility.
            score: (n, m) scores (softmax over queries at the call site).
            query: (n, d) flattened queries.
            train: unused; the former train and test paths were identical.

        Returns:
            (m, d) tensor on ``query``'s device. The former version always
            allocated with ``.cuda()``.
        """
        m, _ = mem.size()
        # One-hot assignment of each query row to its nearest slot: (n, m).
        assignment = F.one_hot(max_indices.reshape(-1), num_classes=m).to(score.dtype)
        # Normalise each slot's column of scores by that column's maximum.
        column_max = torch.max(score, dim=0, keepdim=True)[0]  # (1, m)
        # where(): unassigned entries are exactly 0 even if a column max is 0.
        weights = torch.where(assignment > 0, score / column_max, torch.zeros_like(score))
        return torch.matmul(weights.t(), query)  # (m, d)

    def get_score(self, mem, query):
        """Similarity scores between every query position and every memory item.

        Args:
            mem: (m, d) memory items.
            query: (b, h, w, d) queries.

        Returns:
            ``(score_query, score_memory)``, both (b*h*w, m). They are the
            raw dot products soft-maxed over queries (dim 0) and over memory
            items (dim 1), respectively.
        """
        bs, h, w, _ = query.size()
        m, _ = mem.size()
        score = torch.matmul(query, torch.t(mem))  # b x h x w x m
        score = score.view(bs * h * w, m)  # (b*h*w) x m
        score_query = F.softmax(score, dim=0)
        score_memory = F.softmax(score, dim=1)
        return score_query, score_memory

    def forward(self, query, keys, train=True):
        """Read from (and in training, update) the memory for a feature map.

        Args:
            query: (b, d, h, w) feature map; it is L2-normalised over dim 1.
            keys: (m, d) memory items.
            train: selects the training or test behaviour.

        Returns:
            If ``train`` is True: ``(updated_query, updated_memory,
            softmax_score_query, softmax_score_memory, separateness_loss,
            compactness_loss)``.
            Otherwise: ``(updated_query, keys, softmax_score_query,
            softmax_score_memory, query_reshape, top1_keys, top1_indices,
            compactness_loss)``; in this case the memory is returned unchanged.
        """
        query = F.normalize(query, dim=1)
        query = query.permute(0, 2, 3, 1)  # b x h x w x d

        if train:
            separateness_loss, compactness_loss = self.gather_loss(query, keys, train)
            updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
            updated_memory = self.update(query, keys, train)
            return updated_query, updated_memory, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss

        # test: no memory update
        compactness_loss, query_re, top1_keys, keys_ind = self.gather_loss(query, keys, train)
        updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
        updated_memory = keys
        return updated_query, updated_memory, softmax_score_query, softmax_score_memory, query_re, top1_keys, keys_ind, compactness_loss

    def update(self, query, keys, train):
        """Return the detached, L2-normalised memory after one update step.

        Each memory item is updated by adding the score-weighted sum of the
        queries that picked it as their nearest item (see
        ``get_update_query``).
        """
        batch_size, h, w, dims = query.size()  # b X h X w X d
        softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)

        _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
        _, updating_indices = torch.topk(softmax_score_query, 1, dim=0)

        # The train and test paths used to be identical copies; one call suffices.
        query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train)
        updated_memory = F.normalize(query_update + keys, dim=1)
        return updated_memory.detach()

    def pointwise_gather_loss(self, query_reshape, keys, gathering_indices, train):
        """Element-wise (unreduced) squared error between queries and their gathered keys."""
        loss_mse = torch.nn.MSELoss(reduction='none')
        pointwise_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
        return pointwise_loss

    def gather_loss(self, query, keys, train):
        """Losses between queries and their nearest memory items (keys detached).

        Returns:
            If ``train`` is True: ``(triplet_loss, mse_loss)``. The triplet
            loss (margin 1.0) uses the nearest key as positive and the
            second-nearest as negative, and the MSE is to the nearest key.
            Training needs at least two memory items.
            Otherwise: ``(mse_loss, query_reshape, top1_keys, top1_indices)``.
        """
        batch_size, h, w, dims = query.size()  # b X h X w X d
        loss_mse = torch.nn.MSELoss()
        softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)

        if train:
            loss = torch.nn.TripletMarginLoss(margin=1.0)
            _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)
            # 1st, 2nd closest memories
            pos = keys[gathering_indices[:, 0]]
            neg = keys[gathering_indices[:, 1]]
            top1_loss = loss_mse(query_reshape, pos.detach())
            gathering_loss = loss(query_reshape, pos.detach(), neg.detach())
            return gathering_loss, top1_loss

        _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
        top1_keys = keys[gathering_indices].squeeze(1).detach()
        gathering_loss = loss_mse(query_reshape, top1_keys)
        return gathering_loss, query_reshape, top1_keys, gathering_indices[:, 0]

    def read(self, query, updated_memory):
        """Concatenate each query with its softmax-weighted memory readout.

        Returns ``(updated_query, softmax_score_query, softmax_score_memory)``.
        ``updated_query`` is (b, 2d, h, w); the memory scores used for the
        readout are detached.
        """
        batch_size, h, w, dims = query.size()  # b X h X w X d
        softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)

        concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory)  # (b X h X w) X d
        updated_query = torch.cat((query_reshape, concat_memory), dim=1)  # (b X h X w) X 2d
        updated_query = updated_query.view(batch_size, h, w, 2 * dims)
        updated_query = updated_query.permute(0, 3, 1, 2)
        return updated_query, softmax_score_query, softmax_score_memory
    
    


================================================
FILE: model/utils.py
================================================
import numpy as np
from collections import OrderedDict
import os
import glob
import cv2
import torch.utils.data as data


# Module-level NumPy random generator with a fixed seed (2020) for reproducibility.
rng = np.random.RandomState(seed=2020)

def np_load_frame(filename, resize_height, resize_width):
    """
    Load an image file and convert it to a float32 numpy.ndarray. Note that the
    color channels are BGR (OpenCV order) and the values are rescaled from
    [0, 255] to [-1, 1] via ``x / 127.5 - 1``.

    :param filename: the full path of image
    :param resize_height: resized height
    :param resize_width: resized width
    :return: numpy.ndarray of shape (resize_height, resize_width, channels)
    :raises IOError: if OpenCV cannot read the file
    """
    image_decoded = cv2.imread(filename)
    if image_decoded is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces as an opaque cv2.resize assertion.
        raise IOError('Could not read image file: {}'.format(filename))
    image_resized = cv2.resize(image_decoded, (resize_width, resize_height))  # cv2 takes (w, h)
    image_resized = image_resized.astype(dtype=np.float32)
    image_resized = (image_resized / 127.5) - 1.0
    return image_resized




class DataLoader(data.Dataset):
    """Dataset of fixed-length frame clips drawn from per-video frame folders.

    ``video_folder`` holds one entry per video. Each entry's ``*.jpg`` frames
    are sorted by path, and that sorted order is treated as temporal order.
    Sample ``k`` is the clip of ``time_step + num_pred`` consecutive frames
    that starts at ``self.samples[k]``. Each frame is loaded with
    ``np_load_frame`` and passed through ``transform``, and the results are
    concatenated along axis 0.
    """

    def __init__(self, video_folder, transform, resize_height, resize_width, time_step=4, num_pred=1):
        self.dir = video_folder
        self.transform = transform
        self.videos = OrderedDict()
        self._resize_height = resize_height
        self._resize_width = resize_width
        self._time_step = time_step
        self._num_pred = num_pred
        # Maps every frame path to its position in its video's sorted frame list.
        self._frame_position = {}
        self.setup()
        self.samples = self.get_all_samples()

    def setup(self):
        """Index every entry of ``self.dir`` with its sorted list of ``*.jpg`` frames."""
        for video in sorted(glob.glob(os.path.join(self.dir, '*'))):
            # basename instead of split('/') so Windows paths work too.
            video_name = os.path.basename(video)
            frames = sorted(glob.glob(os.path.join(video, '*.jpg')))
            self.videos[video_name] = {'path': video, 'frame': frames, 'length': len(frames)}
            for position, frame in enumerate(frames):
                self._frame_position[frame] = position

    def get_all_samples(self):
        """Return the first-frame path of every clip that fits entirely inside its video."""
        clip_length = self._time_step + self._num_pred
        frames = []
        for video in self.videos.values():
            # A clip starting at i uses frames i .. i + clip_length - 1. The old
            # bound (length - time_step) overran the video when num_pred > 1;
            # for num_pred == 1 both bounds are identical.
            for i in range(video['length'] - clip_length + 1):
                frames.append(video['frame'][i])
        return frames

    def __getitem__(self, index):
        """Load clip ``index`` as one array concatenated along axis 0.

        NOTE: with ``transform=None`` no frame is appended, so
        ``np.concatenate`` raises ValueError.
        """
        path = self.samples[index]
        video_name = os.path.basename(os.path.dirname(path))
        # Position in the sorted list, rather than the number parsed from the
        # file name, which only matched when frames were numbered from 0
        # without gaps.
        start = self._frame_position[path]
        frame_paths = self.videos[video_name]['frame']

        batch = []
        for offset in range(self._time_step + self._num_pred):
            image = np_load_frame(frame_paths[start + offset], self._resize_height, self._resize_width)
            if self.transform is not None:
                batch.append(self.transform(image))

        return np.concatenate(batch, axis=0)

    def __len__(self):
        return len(self.samples)


================================================
FILE: utils.py
================================================
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.utils as v_utils
import matplotlib.pyplot as plt
import cv2
import math
from collections import OrderedDict
import copy
import time
from sklearn.metrics import roc_auc_score

def rmse(predictions, targets):
    """Return the root-mean-square error between two array-likes.

    Computed as ``sqrt(mean((predictions - targets) ** 2))`` over all elements.
    """
    squared_error = (predictions - targets) ** 2
    mean_squared_error = squared_error.mean()
    return np.sqrt(mean_squared_error)

def psnr(mse, max_val=1.0):
    """Return the peak signal-to-noise ratio, in dB, for a mean squared error.

    Args:
        mse: Mean squared error between two signals. A Python ``0`` raises
            ``ZeroDivisionError``; a negative value raises ``ValueError``
            (math domain error).
        max_val: Peak possible signal value. The default 1.0 matches signals
            scaled to [0, 1]; pass e.g. 255 for 8-bit images.

    Returns:
        ``10 * log10(max_val ** 2 / mse)`` as a float.
    """
    return 10 * math.log10(max_val ** 2 / mse)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Only ``optimizer.param_groups[0]['lr']`` is read; returns None when the
    optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']


def normalize_img(img):
    """Min-max rescale *img* so that its values span [0, 1].

    Args:
        img: Numeric array (e.g. ``np.ndarray``). It is not modified.

    Returns:
        A new array ``(img - min) / (max - min)``. For a constant image
        (``max == min``) an all-zero array is returned rather than NaNs.
    """
    lo = np.min(img)
    span = np.max(img) - lo
    # The subtraction allocates a new array, so the caller's input is never
    # mutated; no defensive copy is needed.
    shifted = img - lo
    if span == 0:
        # Avoid 0/0 -> NaN (with a RuntimeWarning). `shifted` is all zeros here,
        # and dividing it by (span + 1) == 1 keeps the dtype the normal path
        # would produce.
        return shifted / (span + 1)
    return shifted / span

def point_score(outputs, imgs):
    """Return an error-weighted mean squared error for the first item of a batch.

    ``outputs[0]`` and ``imgs[0]`` are each mapped through ``(x + 1) / 2``.
    This presumably maps a [-1, 1] value range to [0, 1]; confirm against the
    data pipeline. Each element's squared error ``e`` is then weighted by
    ``1 - exp(-e)``, so larger errors dominate the average.

    Returns:
        A Python float: ``sum(w * e) / sum(w)``. When every weight is zero
        (identical or numerically identical inputs) this is 0/0; 0.0 is
        returned instead of NaN.
    """
    loss_func_mse = nn.MSELoss(reduction='none')
    # Per-element squared error, computed once and reused for weights and score.
    error = loss_func_mse((outputs[0] + 1) / 2, (imgs[0] + 1) / 2)
    normal = 1 - torch.exp(-error)
    weight_sum = torch.sum(normal)
    if weight_sum.item() == 0:
        # All errors are zero, so the weighted mean's limit is 0.
        return 0.0
    score = (torch.sum(normal * error) / weight_sum).item()
    return score
    
def anomaly_score(psnr, max_psnr, min_psnr):
    """Linearly map *psnr* from ``[min_psnr, max_psnr]`` onto ``[0, 1]``.

    Returns ``(psnr - min_psnr) / (max_psnr - min_psnr)``. There is no guard
    for ``max_psnr == min_psnr`` (division by zero).
    """
    psnr_range = max_psnr - min_psnr
    offset = psnr - min_psnr
    return offset / psnr_range

def anomaly_score_inv(psnr, max_psnr, min_psnr):
    """Return one minus the min-max normalized *psnr*.

    Computes ``1.0 - (psnr - min_psnr) / (max_psnr - min_psnr)``, so
    ``max_psnr`` maps to 0 and ``min_psnr`` maps to 1. There is no guard for
    ``max_psnr == min_psnr``.
    """
    normalized = (psnr - min_psnr) / (max_psnr - min_psnr)
    return 1.0 - normalized

def anomaly_score_list(psnr_list):
    """Min-max normalize each value in *psnr_list* using the list's own range.

    Args:
        psnr_list: Sized sequence of PSNR values.

    Returns:
        ``[anomaly_score(p, max, min) for p in psnr_list]``: the largest value
        maps to 1 and the smallest to 0. An empty input yields ``[]``. If all
        values are equal the range is zero; with numpy scalars this typically
        produces NaN with a RuntimeWarning.
    """
    if len(psnr_list) == 0:
        return []
    # Compute the range once. Calling np.max/np.min per element re-scans the
    # whole list each time, which made the loop O(n^2).
    max_psnr = np.max(psnr_list)
    min_psnr = np.min(psnr_list)
    return [anomaly_score(value, max_psnr, min_psnr) for value in psnr_list]

def anomaly_score_list_inv(psnr_list):
    """Return ``1 - min-max normalized`` score for each value in *psnr_list*.

    Args:
        psnr_list: Sized sequence of PSNR values.

    Returns:
        ``[anomaly_score_inv(p, max, min) for p in psnr_list]``: the largest
        value maps to 0 and the smallest to 1. An empty input yields ``[]``.
        If all values are equal the range is zero; with numpy scalars this
        typically produces NaN with a RuntimeWarning.
    """
    if len(psnr_list) == 0:
        return []
    # Compute the range once. Calling np.max/np.min per element re-scans the
    # whole list each time, which made the loop O(n^2).
    max_psnr = np.max(psnr_list)
    min_psnr = np.min(psnr_list)
    return [anomaly_score_inv(value, max_psnr, min_psnr) for value in psnr_list]

def AUC(anomal_scores, labels):
    """Return the ROC AUC of *anomal_scores* against ground-truth *labels*.

    Args:
        anomal_scores: Scores, one per frame; all size-1 axes are removed with
            ``np.squeeze`` before scoring.
        labels: Ground-truth labels, either with a leading axis of size 1
            (shape ``(1, N)``) or already flat (shape ``(N,)``).

    Returns:
        ``sklearn.metrics.roc_auc_score(y_true=labels, y_score=scores)``.
    """
    labels = np.asarray(labels)
    # Flat labels are used as-is (np.squeeze(axis=0) raises on them). Every
    # other rank keeps the original behaviour: drop the leading axis.
    if labels.ndim != 1:
        labels = np.squeeze(labels, axis=0)
    frame_auc = roc_auc_score(y_true=labels, y_score=np.squeeze(anomal_scores))
    return frame_auc

def score_sum(list1, list2, alpha):
    """Blend two score sequences element-wise.

    Returns a new list whose ``i``-th entry is
    ``alpha * list1[i] + (1 - alpha) * list2[i]``. The result has
    ``len(list1)`` entries. *list2* must be at least that long, otherwise an
    ``IndexError`` is raised; extra trailing items of *list2* are ignored.
    """
    beta = 1 - alpha
    return [alpha * first + beta * list2[idx] for idx, first in enumerate(list1)]
Download .txt
gitextract_9fesjy9r/

├── Evaluate.py
├── MNAD_files/
│   └── style.css
├── README.md
├── Train.py
├── data/
│   ├── data_seqkey_all.py
│   ├── frame_labels_avenue.npy
│   ├── frame_labels_ped2.npy
│   └── frame_labels_shanghai.npy
├── model/
│   ├── Memory.py
│   ├── Reconstruction.py
│   ├── final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1.py
│   ├── memory_final_spatial_sumonly_weight_ranking_top1.py
│   └── utils.py
└── utils.py
Download .txt
SYMBOL INDEX (86 symbols across 7 files)

FILE: data/data_seqkey_all.py
  function pil_loader (line 12) | def pil_loader(path):
  function accimage_loader (line 19) | def accimage_loader(path):
  function default_loader (line 28) | def default_loader(path):
  function make_dataset (line 35) | def make_dataset(dir, class_to_idx):
  class DatasetFolder (line 60) | class DatasetFolder(data.Dataset):
    method __init__ (line 63) | def __init__(self, root, loader=default_loader,transform=None, target_...
    method _find_classes (line 85) | def _find_classes(self, dir):
    method __getitem__ (line 104) | def __getitem__(self, index):
    method _stride (line 145) | def _stride(self):
    method __len__ (line 154) | def __len__(self):
    method __repr__ (line 157) | def __repr__(self):
  class ImageFolder (line 174) | class ImageFolder(DatasetFolder):
    method __init__ (line 178) | def __init__(self, root, transform=None, target_transform=None,

FILE: model/Memory.py
  function random_uniform (line 11) | def random_uniform(shape, low, high, cuda):
  function distance (line 19) | def distance(a, b):
  function distance_batch (line 22) | def distance_batch(a, b):
  function multiply (line 30) | def multiply(x): #to flatten matrix into a vector
  function flatten (line 33) | def flatten(x):
  function index (line 38) | def index(batch_size, x):
  function MemoryLoss (line 43) | def MemoryLoss(memory):
  class Memory (line 54) | class Memory(nn.Module):
    method __init__ (line 55) | def __init__(self, memory_size, feature_dim, key_dim,  temp_update, te...
    method hard_neg_mem (line 64) | def hard_neg_mem(self, mem, i):
    method random_pick_memory (line 72) | def random_pick_memory(self, mem, max_indices):
    method get_update_query (line 87) | def get_update_query(self, mem, max_indices, update_indices, score, qu...
    method get_score (line 125) | def get_score(self, mem, query):
    method forward (line 137) | def forward(self, query, keys, train=True):
    method update (line 172) | def update(self, query, keys,train):
    method pointwise_gather_loss (line 200) | def pointwise_gather_loss(self, query_reshape, keys, gathering_indices...
    method spread_loss (line 208) | def spread_loss(self,query, keys, train):
    method gather_loss (line 227) | def gather_loss(self, query, keys, train):
    method read (line 246) | def read(self, query, updated_memory):

FILE: model/Reconstruction.py
  class Encoder (line 9) | class Encoder(torch.nn.Module):
    method __init__ (line 10) | def __init__(self, t_length = 2, n_channel =3):
    method forward (line 44) | def forward(self, x):
  class Decoder (line 61) | class Decoder(torch.nn.Module):
    method __init__ (line 62) | def __init__(self, t_length = 2, n_channel =3):
    method forward (line 108) | def forward(self, x):
  class convAE (line 127) | class convAE(torch.nn.Module):
    method __init__ (line 128) | def __init__(self, n_channel =3,  t_length = 2, memory_size = 10, feat...
    method forward (line 136) | def forward(self, x, keys,train=True):

FILE: model/final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1.py
  class Encoder (line 9) | class Encoder(torch.nn.Module):
    method __init__ (line 10) | def __init__(self, t_length = 5, n_channel =3):
    method forward (line 44) | def forward(self, x):
  class Decoder (line 61) | class Decoder(torch.nn.Module):
    method __init__ (line 62) | def __init__(self, t_length = 5, n_channel =3):
    method forward (line 108) | def forward(self, x, skip1, skip2, skip3):
  class convAE (line 130) | class convAE(torch.nn.Module):
    method __init__ (line 131) | def __init__(self, n_channel =3,  t_length = 5, memory_size = 10, feat...
    method forward (line 139) | def forward(self, x, keys,train=True):

FILE: model/memory_final_spatial_sumonly_weight_ranking_top1.py
  function random_uniform (line 11) | def random_uniform(shape, low, high, cuda):
  function distance (line 19) | def distance(a, b):
  function distance_batch (line 22) | def distance_batch(a, b):
  function multiply (line 30) | def multiply(x): #to flatten matrix into a vector
  function flatten (line 33) | def flatten(x):
  function index (line 38) | def index(batch_size, x):
  function MemoryLoss (line 43) | def MemoryLoss(memory):
  class Memory (line 54) | class Memory(nn.Module):
    method __init__ (line 55) | def __init__(self, memory_size, feature_dim, key_dim,  temp_update, te...
    method hard_neg_mem (line 64) | def hard_neg_mem(self, mem, i):
    method random_pick_memory (line 72) | def random_pick_memory(self, mem, max_indices):
    method get_update_query (line 87) | def get_update_query(self, mem, max_indices, update_indices, score, qu...
    method get_score (line 116) | def get_score(self, mem, query):
    method forward (line 128) | def forward(self, query, keys, train=True):
    method update (line 161) | def update(self, query, keys,train):
    method pointwise_gather_loss (line 184) | def pointwise_gather_loss(self, query_reshape, keys, gathering_indices...
    method gather_loss (line 192) | def gather_loss(self,query, keys, train):
    method read (line 228) | def read(self, query, updated_memory):

FILE: model/utils.py
  function np_load_frame (line 11) | def np_load_frame(filename, resize_height, resize_width):
  class DataLoader (line 30) | class DataLoader(data.Dataset):
    method __init__ (line 31) | def __init__(self, video_folder, transform, resize_height, resize_widt...
    method setup (line 43) | def setup(self):
    method get_all_samples (line 54) | def get_all_samples(self):
    method __getitem__ (line 65) | def __getitem__(self, index):
    method __len__ (line 78) | def __len__(self):

FILE: utils.py
  function rmse (line 17) | def rmse(predictions, targets):
  function psnr (line 20) | def psnr(mse):
  function get_lr (line 24) | def get_lr(optimizer):
  function normalize_img (line 29) | def normalize_img(img):
  function point_score (line 37) | def point_score(outputs, imgs):
  function anomaly_score (line 45) | def anomaly_score(psnr, max_psnr, min_psnr):
  function anomaly_score_inv (line 48) | def anomaly_score_inv(psnr, max_psnr, min_psnr):
  function anomaly_score_list (line 51) | def anomaly_score_list(psnr_list):
  function anomaly_score_list_inv (line 58) | def anomaly_score_list_inv(psnr_list):
  function AUC (line 65) | def AUC(anomal_scores, labels):
  function score_sum (line 69) | def score_sum(list1, list2, alpha):
Condensed preview — 14 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (63K chars).
[
  {
    "path": "Evaluate.py",
    "chars": 7054,
    "preview": "import numpy as np\nimport os\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch."
  },
  {
    "path": "MNAD_files/style.css",
    "chars": 1916,
    "preview": "/* Space out content a bit */\n\n@import url('https://fonts.googleapis.com/css?family=Baloo|Bungee+Inline|Lato|Righteous|S"
  },
  {
    "path": "README.md",
    "chars": 3772,
    "preview": "# PyTorch implementation of \"Learning Memory-guided Normality for Anomaly Detection\"\n\n<p align=\"center\"><img src=\"./MNAD"
  },
  {
    "path": "Train.py",
    "chars": 6878,
    "preview": "import numpy as np\nimport os\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch."
  },
  {
    "path": "data/data_seqkey_all.py",
    "chars": 5852,
    "preview": "import numpy as np\nimport os\nimport torch\nimport torch.utils.data as data\nimport torchvision.transforms as transforms\nfr"
  },
  {
    "path": "model/Memory.py",
    "chars": 9091,
    "preview": "import torch\nimport torch.autograd as ag\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport"
  },
  {
    "path": "model/Reconstruction.py",
    "chars": 5909,
    "preview": "import numpy as np\nimport os\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom .Memory "
  },
  {
    "path": "model/final_future_prediction_with_memory_spatial_sumonly_weight_ranking_top1.py",
    "chars": 6295,
    "preview": "import numpy as np\nimport os\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom .memory_"
  },
  {
    "path": "model/memory_final_spatial_sumonly_weight_ranking_top1.py",
    "chars": 8532,
    "preview": "import torch\nimport torch.autograd as ag\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport"
  },
  {
    "path": "model/utils.py",
    "chars": 2792,
    "preview": "import numpy as np\nfrom collections import OrderedDict\nimport os\nimport glob\nimport cv2\nimport torch.utils.data as data\n"
  },
  {
    "path": "utils.py",
    "chars": 2077,
    "preview": "import numpy as np\nimport os\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchv"
  }
]

// ... and 3 more files (download for full content)

About this extraction

This page contains the full source code of the cvlab-yonsei/MNAD GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 14 files (58.8 KB), approximately 15.0k tokens, and a symbol index with 86 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!