[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 Irhum Shafkat\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# R2Plus1D-PyTorch\nPyTorch implementation of the R2Plus1D convolution based ResNet architecture described in the paper \"A Closer Look at Spatiotemporal Convolutions for Action Recognition\"\n\nLink to original: [paper](https://arxiv.org/abs/1711.11248) and [code](https://github.com/facebookresearch/R2Plus1D)\n\n***NOTE: This repository has been archived, although forks and other work that extend on top of this remain welcome***\n\n## Requirements \n\nR2Plus1D-PyTorch has the following requirements\n\n* PyTorch 0.4 and dependencies\n* OpenCV (tested on 3.4.0.12)\n* tqdm (for progress bars)\n\n### About this repository\n\nThis repository consists of four python files:\n\n* `module.py` - Contains an implementation of the factored, R2Plus1D convolution the entire implementation is based around. It is designed to be a replacement for nn.Conv3D in the appropriate scenario\n* `network.py` - Uses `module.py` to build up the residual network described in the paper\n* `dataset.py` - Implements a PyTorch dataset, that can load videos with appropriate labels from a given directory.\n* `trainer.py` - A mildly modified version of the script from the PyTorch [tutorials](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) to train the model. Features saving and restoring capabilities. \n\n### Training on Kinetics-400/600\n\nThis repository does not include a crawler or downloader for the Kinetics-400/600 dataset, however, one can be found [here](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics). It is strongly recommended to downsample the videos prior to training (and not on the fly), using a tool such as ffmpeg. If using the crawler, this can be done by adding `\"-vf\", \"scale=172:128\"` to the ffmpeg command list in the download clip function.\n\n### Training in general\n\nThis repository is designed for the ResNet to be trained on any dataset of videos in general, using the VideoDataloader class from dataset.py . It expects the videos to be arranged in a directory -> [train/val] folders -> [class_label] folders (one for each class) -> videos (the files themselves). \n\nForks and fixes of this repo are highly welcome!\n"
  },
  {
    "path": "dataset.py",
    "content": "import os\nfrom pathlib import Path\n\nimport cv2\nimport numpy as np\nfrom torch.utils.data import DataLoader, Dataset\n\n\nclass VideoDataset(Dataset):\n    r\"\"\"A Dataset for a folder of videos. Expects the directory structure to be\n    directory->[train/val/test]->[class labels]->[videos]. Initializes with a list \n    of all file names, along with an array of labels, with label being automatically\n    inferred from the respective folder names.\n\n        Args:\n            directory (str): The path to the directory containing the train/val/test datasets\n            mode (str, optional): Determines which folder of the directory the dataset will read from. Defaults to 'train'. \n            clip_len (int, optional): Determines how many frames are there in each clip. Defaults to 8. \n        \"\"\"\n\n    def __init__(self, directory, mode='train', clip_len=8):\n        folder = Path(directory)/mode  # get the directory of the specified split\n\n        self.clip_len = clip_len\n\n        # the following three parameters are chosen as described in the paper section 4.1\n        self.resize_height = 128  \n        self.resize_width = 171\n        self.crop_size = 112\n\n        # obtain all the filenames of files inside all the class folders \n        # going through each class folder one at a time\n        self.fnames, labels = [], []\n        for label in sorted(os.listdir(folder)):\n            for fname in os.listdir(os.path.join(folder, label)):\n                self.fnames.append(os.path.join(folder, label, fname))\n                labels.append(label)     \n\n        # prepare a mapping between the label names (strings) and indices (ints)\n        self.label2index = {label:index for index, label in enumerate(sorted(set(labels)))} \n        # convert the list of label names into an array of label indices\n        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)        \n\n    def __getitem__(self, index):\n        # loading and preprocessing. TODO move them to transform classes\n        buffer = self.loadvideo(self.fnames[index])\n        buffer = self.crop(buffer, self.clip_len, self.crop_size)\n        buffer = self.normalize(buffer)\n\n        return buffer, self.label_array[index]    \n        \n        \n\n    def loadvideo(self, fname):\n        # initialize a VideoCapture object to read video data into a numpy array\n        capture = cv2.VideoCapture(fname)\n        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))\n        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))\n        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))\n        # create a buffer. Must have dtype float, so it gets converted to a FloatTensor by Pytorch later\n        buffer = np.empty((frame_count, self.resize_height, self.resize_width, 3), np.dtype('float32'))\n\n        count = 0\n        retaining = True\n\n        # read in each frame, one at a time into the numpy buffer array\n        while (count < frame_count and retaining):\n            retaining, frame = capture.read()\n            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            # will resize frames if not already final size\n            # NOTE: strongly recommended to resize them during the download process. 
This script\n            # will process videos of any size, but will take longer the larger the video file.\n            if (frame_height != self.resize_height) or (frame_width != self.resize_width):\n                frame = cv2.resize(frame, (self.resize_width, self.resize_height))\n            buffer[count] = frame\n            count += 1\n\n        # release the VideoCapture once it is no longer needed\n        capture.release()\n\n        # convert from [D, H, W, C] format to [C, D, H, W] (what PyTorch uses)\n        # D = Depth (in this case, time), H = Height, W = Width, C = Channels\n        buffer = buffer.transpose((3, 0, 1, 2))\n\n        return buffer \n    \n    def crop(self, buffer, clip_len, crop_size):\n        # randomly select time index for temporal jittering\n        time_index = np.random.randint(buffer.shape[1] - clip_len)\n        # randomly select start indices in order to crop the video\n        height_index = np.random.randint(buffer.shape[2] - crop_size)\n        width_index = np.random.randint(buffer.shape[3] - crop_size)\n\n        # crop and jitter the video using indexing. The spatial crop is performed on \n        # the entire array, so each frame is cropped in the same location. The temporal\n        # jitter takes place via the selection of consecutive frames\n        buffer = buffer[:, time_index:time_index + clip_len,\n                        height_index:height_index + crop_size,\n                        width_index:width_index + crop_size]\n\n        return buffer                \n\n    def normalize(self, buffer):\n        # Normalize the buffer\n        # NOTE: Default values of RGB images normalization are used, as precomputed \n        # mean and std_dev values (akin to ImageNet) were unavailable for Kinetics. Feel \n        # free to push to and edit this section to replace them if found. \n        buffer = (buffer - 128)/128\n        return buffer\n\n    def __len__(self):\n        return len(self.fnames)\n\n\nclass VideoDataset1M(VideoDataset):\n    r\"\"\"Dataset that implements VideoDataset, and produces exactly 1M augmented\n    training samples every epoch.\n        \n        Args:\n            directory (str): The path to the directory containing the train/val/test datasets\n            mode (str, optional): Determines which folder of the directory the dataset will read from. Defaults to 'train'. \n            clip_len (int, optional): Determines how many frames are there in each clip. Defaults to 8. \n        \"\"\"\n    def __init__(self, directory, mode='train', clip_len=8):\n        # Initialize instance of original dataset class\n        super(VideoDataset1M, self).__init__(directory, mode, clip_len)\n\n    def __getitem__(self, index):\n        # if we are to have 1M samples on every pass, we need to shuffle\n        # the index to a number in the original range, or else we'll get an \n        # index error. This is a legitimate operation, as even with the same \n        # index being used multiple times, it'll be randomly cropped, and\n        # be temporally jitterred differently on each pass, properly\n        # augmenting the data. \n        index = np.random.randint(len(self.fnames))\n\n        buffer = self.loadvideo(self.fnames[index])\n        buffer = self.crop(buffer, self.clip_len, self.crop_size)\n        buffer = self.normalize(buffer)\n\n        return buffer, self.label_array[index]    \n\n    def __len__(self):\n        return 1000000  # manually set the length to 1 million"
  },
  {
    "path": "module.py",
    "content": "import math\n\nimport torch.nn as nn\nfrom torch.nn.modules.utils import _triple\n\n\nclass SpatioTemporalConv(nn.Module):\n    r\"\"\"Applies a factored 3D convolution over an input signal composed of several input \n    planes with distinct spatial and time axes, by performing a 2D convolution over the \n    spatial axes to an intermediate subspace, followed by a 1D convolution over the time \n    axis to produce the final output.\n\n    Args:\n        in_channels (int): Number of channels in the input tensor\n        out_channels (int): Number of channels produced by the convolution\n        kernel_size (int or tuple): Size of the convolving kernel\n        stride (int or tuple, optional): Stride of the convolution. Default: 1\n        padding (int or tuple, optional): Zero-padding added to the sides of the input during their respective convolutions. Default: 0\n        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):\n        super(SpatioTemporalConv, self).__init__()\n\n        # if ints are entered, convert them to iterables, 1 -> [1, 1, 1]\n        kernel_size = _triple(kernel_size)\n        stride = _triple(stride)\n        padding = _triple(padding)\n\n        # decomposing the parameters into spatial and temporal components by\n        # masking out the values with the defaults on the axis that\n        # won't be convolved over. This is necessary to avoid unintentional\n        # behavior such as padding being added twice\n        spatial_kernel_size =  [1, kernel_size[1], kernel_size[2]]\n        spatial_stride =  [1, stride[1], stride[2]]\n        spatial_padding =  [0, padding[1], padding[2]]\n\n        temporal_kernel_size = [kernel_size[0], 1, 1]\n        temporal_stride =  [stride[0], 1, 1]\n        temporal_padding =  [padding[0], 0, 0]\n\n        # compute the number of intermediary channels (M) using formula \n        # from the paper section 3.5\n        intermed_channels = int(math.floor((kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels)/ \\\n                            (kernel_size[1]* kernel_size[2] * in_channels + kernel_size[0] * out_channels)))\n\n        # the spatial conv is effectively a 2D conv due to the \n        # spatial_kernel_size, followed by batch_norm and ReLU\n        self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size,\n                                    stride=spatial_stride, padding=spatial_padding, bias=bias)\n        self.bn = nn.BatchNorm3d(intermed_channels)\n        self.relu = nn.ReLU()\n\n        # the temporal conv is effectively a 1D conv, but has batch norm \n        # and ReLU added inside the model constructor, not here. This is an \n        # intentional design choice, to allow this module to externally act \n        # identical to a standard Conv3D, so it can be reused easily in any \n        # other codebase\n        self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size, \n                                    stride=temporal_stride, padding=temporal_padding, bias=bias)\n\n    def forward(self, x):\n        x = self.relu(self.bn(self.spatial_conv(x)))\n        x = self.temporal_conv(x)\n        return x\n"
  },
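  {
    "path": "example_module.py",
    "content": "\"\"\"Minimal usage sketch for module.py (illustrative only; file name and sizes are arbitrary).\n\nShows SpatioTemporalConv constructed with the same arguments as a standard\nnn.Conv3d; both layers produce outputs of the same shape, which is what makes\nthe factored (2+1)D convolution usable as a replacement for nn.Conv3d.\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom module import SpatioTemporalConv\n\nif __name__ == \"__main__\":\n    # a dummy clip batch: [batch, channels, frames, height, width]\n    x = torch.randn(2, 3, 8, 112, 112)\n\n    # factored (2+1)D convolution vs. a standard 3D convolution with identical arguments\n    factored = SpatioTemporalConv(3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])\n    standard = nn.Conv3d(3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])\n\n    # both should print torch.Size([2, 64, 8, 56, 56])\n    print(factored(x).shape)\n    print(standard(x).shape)\n"
  },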
  {
    "path": "network.py",
    "content": "import torch.nn as nn\nfrom torch.nn.modules.utils import _triple\n\nfrom module import SpatioTemporalConv\n\n\nclass SpatioTemporalResBlock(nn.Module):\n    r\"\"\"Single block for the ResNet network. Uses SpatioTemporalConv in \n        the standard ResNet block layout (conv->batchnorm->ReLU->conv->batchnorm->sum->ReLU)\n        \n        Args:\n            in_channels (int): Number of channels in the input tensor.\n            out_channels (int): Number of channels in the output produced by the block.\n            kernel_size (int or tuple): Size of the convolving kernels.\n            downsample (bool, optional): If ``True``, the output size is to be smaller than the input. Default: ``False``\n        \"\"\"\n    def __init__(self, in_channels, out_channels, kernel_size, downsample=False):\n        super(SpatioTemporalResBlock, self).__init__()\n        \n        # If downsample == True, the first conv of the layer has stride = 2 \n        # to halve the residual output size, and the input x is passed \n        # through a seperate 1x1x1 conv with stride = 2 to also halve it.\n\n        # no pooling layers are used inside ResNet\n        self.downsample = downsample\n        \n        # to allow for SAME padding\n        padding = kernel_size//2\n\n        if self.downsample:\n            # downsample with stride =2 the input x\n            self.downsampleconv = SpatioTemporalConv(in_channels, out_channels, 1, stride=2)\n            self.downsamplebn = nn.BatchNorm3d(out_channels)\n\n            # downsample with stride = 2when producing the residual\n            self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding, stride=2)\n        else:\n            self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding)\n\n        self.bn1 = nn.BatchNorm3d(out_channels)\n        self.relu1 = nn.ReLU()\n\n        # standard conv->batchnorm->ReLU\n        self.conv2 = SpatioTemporalConv(out_channels, out_channels, kernel_size, padding=padding)\n        self.bn2 = nn.BatchNorm3d(out_channels)\n        self.outrelu = nn.ReLU()\n\n    def forward(self, x):\n        res = self.relu1(self.bn1(self.conv1(x)))    \n        res = self.bn2(self.conv2(res))\n\n        if self.downsample:\n            x = self.downsamplebn(self.downsampleconv(x))\n\n        return self.outrelu(x + res)\n\n\nclass SpatioTemporalResLayer(nn.Module):\n    r\"\"\"Forms a single layer of the ResNet network, with a number of repeating \n    blocks of same output size stacked on top of each other\n        \n        Args:\n            in_channels (int): Number of channels in the input tensor.\n            out_channels (int): Number of channels in the output produced by the layer.\n            kernel_size (int or tuple): Size of the convolving kernels.\n            layer_size (int): Number of blocks to be stacked to form the layer\n            block_type (Module, optional): Type of block that is to be used to form the layer. Default: SpatioTemporalResBlock. \n            downsample (bool, optional): If ``True``, the first block in layer will implement downsampling. 
Default: ``False``\n        \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, layer_size, block_type=SpatioTemporalResBlock, downsample=False):\n        \n        super(SpatioTemporalResLayer, self).__init__()\n\n        # implement the first block\n        self.block1 = block_type(in_channels, out_channels, kernel_size, downsample)\n\n        # prepare module list to hold all (layer_size - 1) blocks\n        self.blocks = nn.ModuleList([])\n        for i in range(layer_size - 1):\n            # all these blocks are identical, and have downsample = False by default\n            self.blocks += [block_type(out_channels, out_channels, kernel_size)]\n\n    def forward(self, x):\n        x = self.block1(x)\n        for block in self.blocks:\n            x = block(x)\n\n        return x\n\n\nclass R2Plus1DNet(nn.Module):\n    r\"\"\"Forms the overall ResNet feature extractor by initializng 5 layers, with the number of blocks in \n    each layer set by layer_sizes, and by performing a global average pool at the end producing a \n    512-dimensional vector for each element in the batch.\n        \n        Args:\n            layer_sizes (tuple): An iterable containing the number of blocks in each layer\n            block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock. \n        \"\"\"\n    def __init__(self, layer_sizes, block_type=SpatioTemporalResBlock):\n        super(R2Plus1DNet, self).__init__()\n\n        # first conv, with stride 1x2x2 and kernel size 3x7x7\n        self.conv1 = SpatioTemporalConv(3, 64, [3, 7, 7], stride=[1, 2, 2], padding=[1, 3, 3])\n        # output of conv2 is same size as of conv1, no downsampling needed. kernel_size 3x3x3\n        self.conv2 = SpatioTemporalResLayer(64, 64, 3, layer_sizes[0], block_type=block_type)\n        # each of the final three layers doubles num_channels, while performing downsampling \n        # inside the first block\n        self.conv3 = SpatioTemporalResLayer(64, 128, 3, layer_sizes[1], block_type=block_type, downsample=True)\n        self.conv4 = SpatioTemporalResLayer(128, 256, 3, layer_sizes[2], block_type=block_type, downsample=True)\n        self.conv5 = SpatioTemporalResLayer(256, 512, 3, layer_sizes[3], block_type=block_type, downsample=True)\n\n        # global average pooling of the output\n        self.pool = nn.AdaptiveAvgPool3d(1)\n    \n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.conv2(x)\n        x = self.conv3(x)\n        x = self.conv4(x)\n        x = self.conv5(x)\n\n        x = self.pool(x)\n        \n        return x.view(-1, 512)\n\nclass R2Plus1DClassifier(nn.Module):\n    r\"\"\"Forms a complete ResNet classifier producing vectors of size num_classes, by initializng 5 layers, \n    with the number of blocks in each layer set by layer_sizes, and by performing a global average pool\n    at the end producing a 512-dimensional vector for each element in the batch, \n    and passing them through a Linear layer.\n        \n        Args:\n            num_classes(int): Number of classes in the data\n            layer_sizes (tuple): An iterable containing the number of blocks in each layer\n            block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock. 
\n        \"\"\"\n    def __init__(self, num_classes, layer_sizes, block_type=SpatioTemporalResBlock):\n        super(R2Plus1DClassifier, self).__init__()\n\n        self.res2plus1d = R2Plus1DNet(layer_sizes, block_type)\n        self.linear = nn.Linear(512, num_classes)\n\n    def forward(self, x):\n        x = self.res2plus1d(x)\n        x = self.linear(x) \n\n        return x   \n"
  },
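  {
    "path": "example_network.py",
    "content": "\"\"\"Minimal usage sketch for network.py (illustrative only; file name, batch size and class count are arbitrary).\n\nRuns a dummy batch of clips through the ResNet18-style R2Plus1DClassifier\n(layer_sizes [2, 2, 2, 2], the default used in trainer.py) to show the expected\ninput and output shapes. The clip length and crop size match the defaults in dataset.py.\n\"\"\"\nimport torch\n\nfrom network import R2Plus1DClassifier\n\nif __name__ == \"__main__\":\n    # dummy batch of 2 clips: [batch, channels, frames, height, width]\n    clips = torch.randn(2, 3, 8, 112, 112)\n\n    model = R2Plus1DClassifier(num_classes=101, layer_sizes=[2, 2, 2, 2])\n\n    # one score vector per clip: torch.Size([2, 101])\n    scores = model(clips)\n    print(scores.shape)\n"
  },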
  {
    "path": "trainer.py",
    "content": "import os\nimport time\n\nimport numpy as np\nimport torch\nfrom torch import nn, optim\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\n\nfrom dataset import VideoDataset, VideoDataset1M\nfrom network import R2Plus1DClassifier\n\n# Use GPU if available else revert to CPU\ndevice = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\nprint(\"Device being used:\", device)\n\ndef train_model(num_classes, directory, layer_sizes=[2, 2, 2, 2], num_epochs=45, save=True, path=\"model_data.pth.tar\"):\n    \"\"\"Initalizes and the model for a fixed number of epochs, using dataloaders from the specified directory, \n    selected optimizer, scheduler, criterion, defualt otherwise. Features saving and restoration capabilities as well. \n    Adapted from the PyTorch tutorial found here: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html\n\n        Args:\n            num_classes (int): Number of classes in the data\n            directory (str): Directory where the data is to be loaded from\n            layer_sizes (list, optional): Number of blocks in each layer. Defaults to [2, 2, 2, 2], equivalent to ResNet18.\n            num_epochs (int, optional): Number of epochs to train for. Defaults to 45. \n            save (bool, optional): If true, the model will be saved to path. Defaults to True. \n            path (str, optional): The directory to load a model checkpoint from, and if save == True, save to. Defaults to \"model_data.pth.tar\".\n    \"\"\"\n\n\n    # initalize the ResNet 18 version of this model\n    model = R2Plus1DClassifier(num_classes=num_classes, layer_sizes=layer_sizes).to(device)\n\n    criterion = nn.CrossEntropyLoss() # standard crossentropy loss for classification\n    optimizer = optim.SGD(model.parameters(), lr=0.01)  # hyperparameters as given in paper sec 4.1\n    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # the scheduler divides the lr by 10 every 10 epochs\n\n    # prepare the dataloaders into a dict\n    train_dataloader = DataLoader(VideoDataset(directory), batch_size=10, shuffle=True, num_workers=4)\n    # IF training on Kinetics-600 and require exactly a million samples each epoch, \n    # import VideoDataset1M and uncomment the following\n    # train_dataloader = DataLoader(VideoDataset1M(directory), batch_size=32, num_workers=4)\n    val_dataloader = DataLoader(VideoDataset(directory, mode='val'), batch_size=14, num_workers=4)\n    dataloaders = {'train': train_dataloader, 'val': val_dataloader}\n\n    dataset_sizes = {x: len(dataloaders[x].dataset) for x in ['train', 'val']}\n\n    # saves the time the process was started, to compute total time at the end\n    start = time.time()\n    epoch_resume = 0\n\n    # check if there was a previously saved checkpoint\n    if os.path.exists(path):\n        # loads the checkpoint\n        checkpoint = torch.load(path)\n        print(\"Reloading from previously saved checkpoint\")\n\n        # restores the model and optimizer state_dicts\n        model.load_state_dict(checkpoint['state_dict'])\n        optimizer.load_state_dict(checkpoint['opt_dict'])\n        \n        # obtains the epoch the training is to resume from\n        epoch_resume = checkpoint[\"epoch\"]\n\n    for epoch in tqdm(range(epoch_resume, num_epochs), unit=\"epochs\", initial=epoch_resume, total=num_epochs):\n        # each epoch has a training and validation step, in that order\n        for phase in ['train', 'val']:\n\n            # reset the running loss and 
corrects\n            running_loss = 0.0\n            running_corrects = 0\n\n            # set model to train() or eval() mode depending on whether it is trained\n            # or being validated. Primarily affects layers such as BatchNorm or Dropout.\n            if phase == 'train':\n                # scheduler.step() is to be called once every epoch during training\n                scheduler.step()\n                model.train()\n            else:\n                model.eval()\n\n\n            for inputs, labels in dataloaders[phase]:\n                # move inputs and labels to the device the training is taking place on\n                inputs = inputs.to(device)\n                labels = labels.to(device)\n                optimizer.zero_grad()\n\n                # keep intermediate states iff backpropagation will be performed. If false, \n                # then all intermediate states will be thrown away during evaluation, to use\n                # the least amount of memory possible.\n                with torch.set_grad_enabled(phase=='train'):\n                    outputs = model(inputs)\n                    # we're interested in the indices on the max values, not the values themselves\n                    _, preds = torch.max(outputs, 1)  \n                    loss = criterion(outputs, labels)\n\n                    # Backpropagate and optimize iff in training mode, else there's no intermediate\n                    # values to backpropagate with and will throw an error.\n                    if phase == 'train':\n                        loss.backward()\n                        optimizer.step()   \n\n                running_loss += loss.item() * inputs.size(0)\n                running_corrects += torch.sum(preds == labels.data)\n\n            epoch_loss = running_loss / dataset_sizes[phase]\n            epoch_acc = running_corrects.double() / dataset_sizes[phase]\n\n            print(f\"{phase} Loss: {epoch_loss} Acc: {epoch_acc}\")\n\n    # save the model if save=True\n    if save:\n        torch.save({\n        'epoch': epoch + 1,\n        'state_dict': model.state_dict(),\n        'acc': epoch_acc,\n        'opt_dict': optimizer.state_dict(),\n        }, path)\n\n    # print the total time needed, HH:MM:SS format\n    time_elapsed = time.time() - start    \n    print(f\"Training complete in {time_elapsed//3600}h {(time_elapsed%3600)//60}m {time_elapsed %60}s\")\n"
  }
]