Repository: guxinqian/Simple-CCReID
Branch: main
Commit: f773d013508a
Files: 45
Total size: 209.0 KB

Directory structure:
gitextract_4atmcr4t/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── c2dres50_ce_cal.yaml
│   ├── default_img.py
│   ├── default_vid.py
│   ├── res50_cels_cal.yaml
│   ├── res50_cels_cal_16x4.yaml
│   └── res50_cels_cal_tri_16x4.yaml
├── data/
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset_loader.py
│   ├── datasets/
│   │   ├── ccvid.py
│   │   ├── deepchange.py
│   │   ├── last.py
│   │   ├── ltcc.py
│   │   ├── prcc.py
│   │   └── vcclothes.py
│   ├── img_transforms.py
│   ├── samplers.py
│   ├── spatial_transforms.py
│   └── temporal_transforms.py
├── losses/
│   ├── __init__.py
│   ├── arcface_loss.py
│   ├── circle_loss.py
│   ├── clothes_based_adversarial_loss.py
│   ├── contrastive_loss.py
│   ├── cosface_loss.py
│   ├── cross_entropy_loss_with_label_smooth.py
│   ├── gather.py
│   └── triplet_loss.py
├── main.py
├── models/
│   ├── __init__.py
│   ├── classifier.py
│   ├── img_resnet.py
│   ├── utils/
│   │   ├── c3d_blocks.py
│   │   ├── inflate.py
│   │   ├── nonlocal_blocks.py
│   │   └── pooling.py
│   └── vid_resnet.py
├── script.sh
├── test.py
├── tools/
│   ├── eval_metrics.py
│   └── utils.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
### A Simple Codebase for Clothes-Changing Person Re-identification.
####  [Clothes-Changing Person Re-identification with RGB Modality Only (CVPR, 2022)](https://arxiv.org/abs/2204.06890)

#### Requirements
- Python 3.6
- PyTorch 1.6.0
- yacs
- apex

#### CCVID Dataset
- [[BaiduYun]](https://pan.baidu.com/s/1W9yjqxS9qxfPUSu76JpE1g) password: q0q2
- [[GoogleDrive]](https://drive.google.com/file/d/1vkZxm5v-aBXa_JEi23MMeW4DgisGtS4W/view?usp=sharing)

#### Get Started
- Replace `_C.DATA.ROOT` and `_C.OUTPUT` in `configs/default_img.py` and `configs/default_vid.py` with your own data path and output path, respectively.
- Run `script.sh`
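`update_config` in both config files also reads `args.root` and `args.output`, and after merging it sets the output folder to `OUTPUT/DATASET/TAG`, so the paths can presumably be supplied on the command line instead of editing the defaults. `main.py` is not shown here; the parser below is a hypothetical sketch that assumes the flag names mirror the attributes `update_config` reads (`cfg`, `root`, `output`, `dataset`, `tag`, `gpu`, `resume`, `eval`, `amp`).

```python
import argparse
import os


def build_parser():
    # Hypothetical parser: flag names are assumed to mirror the attributes
    # read by update_config (args.cfg, args.root, args.output, ...).
    parser = argparse.ArgumentParser(description="Simple-CCReID (sketch)")
    parser.add_argument("--cfg", type=str, required=True, help="path to a YAML config")
    parser.add_argument("--root", type=str, help="dataset root, overrides _C.DATA.ROOT")
    parser.add_argument("--output", type=str, help="output root, overrides _C.OUTPUT")
    parser.add_argument("--dataset", type=str, help="overrides _C.DATA.DATASET")
    parser.add_argument("--tag", type=str, help="overrides _C.TAG")
    parser.add_argument("--gpu", type=str, help="overrides _C.GPU")
    parser.add_argument("--resume", type=str, help="checkpoint path to resume from")
    parser.add_argument("--eval", action="store_true", help="evaluation only")
    parser.add_argument("--amp", action="store_true", help="mixed-precision training")
    return parser


def resolve_output_dir(args, cfg_output, cfg_dataset, cfg_tag):
    # Mirrors the last step of update_config: OUTPUT/DATASET/TAG, where each
    # part comes from the command line if given, else from the merged config.
    output = args.output or cfg_output
    dataset = args.dataset or cfg_dataset
    tag = args.tag or cfg_tag
    return os.path.join(output, dataset, tag)


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--cfg", "configs/res50_cels_cal.yaml", "--root", "/my/data",
         "--output", "/my/logs", "--dataset", "prcc"])
    print(resolve_output_dir(args, "/data/guxinqian/logs/", "ltcc", "res50-cels-cal"))
```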


#### Citation

If you use our code/dataset in your research or wish to refer to the baseline results, please use the following BibTeX entry.
    
    @inproceedings{gu2022CAL,
        title={Clothes-Changing Person Re-identification with RGB Modality Only},
        author={Gu, Xinqian and Chang, Hong and Ma, Bingpeng and Bai, Shutao and Shan, Shiguang and Chen, Xilin},
        booktitle={CVPR},
        year={2022},
    }

#### Related Repos

- [Simple-ReID](https://github.com/guxinqian/Simple-ReID)
- [fast-reid](https://github.com/JDAI-CV/fast-reid)
- [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid)
- [Pytorch ReID](https://github.com/layumi/Person_reID_baseline_pytorch)



================================================
FILE: configs/c2dres50_ce_cal.yaml
================================================
MODEL:
  NAME: c2dres50
LOSS:
  CLA_LOSS: crossentropy
  CAL: cal
TAG: c2dres50-ce-cal
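Each YAML file lists only the keys it changes. `update_config` loads it with yacs `merge_from_file`, which overlays it on the defaults from `default_img.py`/`default_vid.py` and rejects keys that do not exist there. The sketch below imitates that overlay with plain nested dicts (no yacs, no YAML parsing) to show which values end up in effect for this file; it is an approximation of the merge, not yacs itself, and the dicts hold only a subset of the real defaults.

```python
import copy


def merge_overrides(defaults, overrides, prefix=""):
    """Return a copy of `defaults` with `overrides` applied recursively.

    Like the yacs merge, a key that does not exist in the defaults is an error.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        full_key = prefix + key
        if key not in merged:
            raise KeyError("Non-existent config key: {}".format(full_key))
        if isinstance(value, dict) and isinstance(merged[key], dict):
            # Descend into sub-nodes such as MODEL or LOSS.
            merged[key] = merge_overrides(merged[key], value, full_key + ".")
        else:
            merged[key] = value
    return merged


# A subset of the defaults from configs/default_vid.py.
DEFAULTS = {
    "MODEL": {"NAME": "c2dres50", "FEATURE_DIM": 2048},
    "LOSS": {"CLA_LOSS": "crossentropy", "CAL": "cal", "EPSILON": 0.1},
    "TAG": "res50-ce-cal",
}

# The contents of configs/c2dres50_ce_cal.yaml, written as a dict.
OVERRIDES = {
    "MODEL": {"NAME": "c2dres50"},
    "LOSS": {"CLA_LOSS": "crossentropy", "CAL": "cal"},
    "TAG": "c2dres50-ce-cal",
}

cfg = merge_overrides(DEFAULTS, OVERRIDES)
print(cfg["TAG"], cfg["MODEL"]["FEATURE_DIM"])  # c2dres50-ce-cal 2048
```

Keys the YAML omits (here `FEATURE_DIM` and `EPSILON`) keep their default values.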

================================================
FILE: configs/default_img.py
================================================
import os
import yaml
from yacs.config import CfgNode as CN


_C = CN()
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Root path for dataset directory
_C.DATA.ROOT = '/home/guxinqian/data'
# Dataset for evaluation
_C.DATA.DATASET = 'ltcc'
# Workers for dataloader
_C.DATA.NUM_WORKERS = 4
# Height of input image
_C.DATA.HEIGHT = 384
# Width of input image
_C.DATA.WIDTH = 192
# Batch size for training
_C.DATA.TRAIN_BATCH = 32
# Batch size for testing
_C.DATA.TEST_BATCH = 128
# The number of instances per identity for training sampler
_C.DATA.NUM_INSTANCES = 8
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Random crop prob
_C.AUG.RC_PROB = 0.5
# Random erase prob
_C.AUG.RE_PROB = 0.5
# Random flip prob
_C.AUG.RF_PROB = 0.5
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model name
_C.MODEL.NAME = 'resnet50'
# The stride for layer4 in ResNet
_C.MODEL.RES4_STRIDE = 1
# feature dim
_C.MODEL.FEATURE_DIM = 4096
# Model path for resuming
_C.MODEL.RESUME = ''
# Global pooling after the backbone
_C.MODEL.POOLING = CN()
# Choose in ['avg', 'max', 'gem', 'maxavg']
_C.MODEL.POOLING.NAME = 'maxavg'
# Initialized power for GeM pooling
_C.MODEL.POOLING.P = 3
# -----------------------------------------------------------------------------
# Losses for training 
# -----------------------------------------------------------------------------
_C.LOSS = CN()
# Classification loss
_C.LOSS.CLA_LOSS = 'crossentropy'
# Clothes classification loss
_C.LOSS.CLOTHES_CLA_LOSS = 'cosface'
# Scale for classification loss
_C.LOSS.CLA_S = 16.
# Margin for classification loss
_C.LOSS.CLA_M = 0.
# Pairwise loss
_C.LOSS.PAIR_LOSS = 'triplet'
# The weight for pairwise loss
_C.LOSS.PAIR_LOSS_WEIGHT = 0.0
# Scale for pairwise loss
_C.LOSS.PAIR_S = 16.
# Margin for pairwise loss
_C.LOSS.PAIR_M = 0.3
# Clothes-based adversarial loss
_C.LOSS.CAL = 'cal'
# Epsilon for clothes-based adversarial loss
_C.LOSS.EPSILON = 0.1
# Momentum for clothes-based adversarial loss with memory bank
_C.LOSS.MOMENTUM = 0.
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.MAX_EPOCH = 60
# Start epoch for clothes classification
_C.TRAIN.START_EPOCH_CC = 25
# Start epoch for adversarial training
_C.TRAIN.START_EPOCH_ADV = 25
# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adam'
# Learning rate
_C.TRAIN.OPTIMIZER.LR = 0.00035
_C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4
# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
# Stepsize to decay learning rate
_C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40]
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# Using amp for training
_C.TRAIN.AMP = False
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Perform evaluation after every N epochs (set to -1 to test after training)
_C.TEST.EVAL_STEP = 5
# Start to evaluate after specific epoch
_C.TEST.START_EVAL = 0
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Fixed random seed
_C.SEED = 1
# Perform evaluation only
_C.EVAL_MODE = False
# GPU device ids for CUDA_VISIBLE_DEVICES
_C.GPU = '0'
# Path to output folder, overwritten by command line argument
_C.OUTPUT = '/data/guxinqian/logs/'
# Tag of experiment, overwritten by command line argument
_C.TAG = 'res50-ce-cal'


def update_config(config, args):
    config.defrost()
    config.merge_from_file(args.cfg)

    # merge from specific arguments
    if args.root:
        config.DATA.ROOT = args.root
    if args.output:
        config.OUTPUT = args.output

    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.eval:
        config.EVAL_MODE = True
    
    if args.tag:
        config.TAG = args.tag

    if args.dataset:
        config.DATA.DATASET = args.dataset
    if args.gpu:
        config.GPU = args.gpu
    if args.amp:
        config.TRAIN.AMP = True

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG)

    config.freeze()


def get_img_config(args):
    """Get a yacs CfgNode object with default values."""
    config = _C.clone()
    update_config(config, args)

    return config
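`MODEL.POOLING.NAME` chooses how the backbone's spatial feature map is collapsed into one vector, and `MODEL.POOLING.P` initialises the GeM exponent. `models/utils/pooling.py` is not shown in this section, so the following is a stdlib sketch of the standard definitions applied per channel: GeM computes (mean of x^p)^(1/p), which equals average pooling at p = 1 and approaches max pooling as p grows. It also assumes that `maxavg` concatenates the max- and average-pooled vectors; if so, that would be consistent with `FEATURE_DIM` being 4096 here (2 x 2048 ResNet-50 channels) but 2048 in `default_vid.py`.

```python
def avg_pool(channel):
    # Mean activation over all spatial positions of one channel.
    return sum(channel) / len(channel)


def max_pool(channel):
    return max(channel)


def gem_pool(channel, p=3.0, eps=1e-6):
    # Generalized-mean pooling: (mean(x^p))^(1/p). Activations are clamped
    # to eps first, since ReLU outputs may be exactly 0.
    clamped = [max(x, eps) for x in channel]
    return (sum(x ** p for x in clamped) / len(clamped)) ** (1.0 / p)


def pool_features(feature_map, name="maxavg", p=3.0):
    """Pool a feature map, given as a list of channels (each a flat list of
    spatial activations), into a single feature vector."""
    if name == "avg":
        return [avg_pool(c) for c in feature_map]
    if name == "max":
        return [max_pool(c) for c in feature_map]
    if name == "gem":
        return [gem_pool(c, p) for c in feature_map]
    if name == "maxavg":
        # Assumed layout: max-pooled vector followed by the avg-pooled one,
        # which doubles the output dimension.
        return [max_pool(c) for c in feature_map] + [avg_pool(c) for c in feature_map]
    raise KeyError("Invalid pooling: '{}'".format(name))


fmap = [[0.0, 1.0, 2.0, 5.0], [1.0, 1.0, 1.0, 1.0]]  # 2 channels, 4 positions
print(pool_features(fmap, "avg"))          # [2.0, 1.0]
print(pool_features(fmap, "max"))          # [5.0, 1.0]
print(len(pool_features(fmap, "maxavg")))  # 4
```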


================================================
FILE: configs/default_vid.py
================================================
import os
import yaml
from yacs.config import CfgNode as CN


_C = CN()
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Root path for dataset directory
_C.DATA.ROOT = '/home/guxinqian/data'
# Dataset for evaluation
_C.DATA.DATASET = 'ccvid'
# Whether to split each full-length video in the training set into short clips
_C.DATA.DENSE_SAMPLING = True
# Sampling step of dense sampling for training set
_C.DATA.SAMPLING_STEP = 64
# Workers for dataloader
_C.DATA.NUM_WORKERS = 4
# Height of input image
_C.DATA.HEIGHT = 256
# Width of input image
_C.DATA.WIDTH = 128
# Batch size for training
_C.DATA.TRAIN_BATCH = 16
# Batch size for testing
_C.DATA.TEST_BATCH = 128
# The number of instances per identity for training sampler
_C.DATA.NUM_INSTANCES = 4
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Random erase prob
_C.AUG.RE_PROB = 0.0
# Temporal sampling mode for training, 'tsn' or 'stride'
_C.AUG.TEMPORAL_SAMPLING_MODE = 'stride'
# Sequence length of each input video clip
_C.AUG.SEQ_LEN = 8
# Sampling stride of each input video clip
_C.AUG.SAMPLING_STRIDE = 4
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model name. All supported models are listed in models/__init__.py
_C.MODEL.NAME = 'c2dres50'
# The stride for layer4 in ResNet
_C.MODEL.RES4_STRIDE = 1
# feature dim
_C.MODEL.FEATURE_DIM = 2048
# Model path for resuming
_C.MODEL.RESUME = ''
# Params for AP3D
_C.MODEL.AP3D = CN()
# Temperature for APM
_C.MODEL.AP3D.TEMPERATURE = 4
# Contrastive attention
_C.MODEL.AP3D.CONTRACTIVE_ATT = True
# -----------------------------------------------------------------------------
# Losses for training 
# -----------------------------------------------------------------------------
_C.LOSS = CN()
# Classification loss
_C.LOSS.CLA_LOSS = 'crossentropy'
# Clothes classification loss
_C.LOSS.CLOTHES_CLA_LOSS = 'cosface'
# Scale for classification loss
_C.LOSS.CLA_S = 16.
# Margin for classification loss
_C.LOSS.CLA_M = 0.
# Pairwise loss
_C.LOSS.PAIR_LOSS = 'triplet'
# The weight for pairwise loss
_C.LOSS.PAIR_LOSS_WEIGHT = 0.0
# Scale for pairwise loss
_C.LOSS.PAIR_S = 16.
# Margin for pairwise loss
_C.LOSS.PAIR_M = 0.3
# Clothes-based adversarial loss
_C.LOSS.CAL = 'cal'
# Epsilon for clothes-based adversarial loss
_C.LOSS.EPSILON = 0.1
# Momentum for clothes-based adversarial loss with memory bank
_C.LOSS.MOMENTUM = 0.
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.MAX_EPOCH = 150
# Start epoch for clothes classification
_C.TRAIN.START_EPOCH_CC = 50
# Start epoch for adversarial training
_C.TRAIN.START_EPOCH_ADV = 50
# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adam'
# Learning rate
_C.TRAIN.OPTIMIZER.LR = 0.00035
_C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4
# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
# Stepsize to decay learning rate
_C.TRAIN.LR_SCHEDULER.STEPSIZE = [40, 80, 120]
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# Using amp for training
_C.TRAIN.AMP = False
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Perform evaluation after every N epochs (set to -1 to test after training)
_C.TEST.EVAL_STEP = 10
# Start to evaluate after specific epoch
_C.TEST.START_EVAL = 0
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Fixed random seed
_C.SEED = 1
# Perform evaluation only
_C.EVAL_MODE = False
# GPU device ids for CUDA_VISIBLE_DEVICES
_C.GPU = '0, 1'
# Path to output folder, overwritten by command line argument
_C.OUTPUT = '/data/guxinqian/logs/'
# Tag of experiment, overwritten by command line argument
_C.TAG = 'res50-ce-cal'


def update_config(config, args):
    config.defrost()
    config.merge_from_file(args.cfg)

    # merge from specific arguments
    if args.root:
        config.DATA.ROOT = args.root
    if args.output:
        config.OUTPUT = args.output

    if args.resume:
        config.MODEL.RESUME = args.resume
    if args.eval:
        config.EVAL_MODE = True
    
    if args.tag:
        config.TAG = args.tag

    if args.dataset:
        config.DATA.DATASET = args.dataset
    if args.gpu:
        config.GPU = args.gpu
    if args.amp:
        config.TRAIN.AMP = True

    # output folder
    config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG)

    config.freeze()


def get_vid_config(args):
    """Get a yacs CfgNode object with default values."""
    config = _C.clone()
    update_config(config, args)

    return config
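With the defaults above, a training clip in `'stride'` mode has `SEQ_LEN = 8` frames taken every `SAMPLING_STRIDE = 4` frames, so it spans (8 - 1) * 4 + 1 = 29 frames of the source video; with `DENSE_SAMPLING` on, each training video is first cut into pieces of `SAMPLING_STEP = 64` frames. `data/temporal_transforms.py` and the dataset code are not part of this section, so the functions below are a hypothetical stdlib sketch of those ideas. In particular, padding short videos by repeating the index list, and the exact `'tsn'` segment rule, are assumptions.

```python
import random


def dense_split(frame_indices, step=64):
    """Cut one training video into consecutive pieces of at most `step` frames."""
    return [frame_indices[i:i + step] for i in range(0, len(frame_indices), step)]


def stride_crop(frame_indices, seq_len=8, stride=4, rng=random):
    """Pick `seq_len` frames spaced `stride` apart from a random start."""
    if not frame_indices:
        raise ValueError("empty video")
    span = (seq_len - 1) * stride + 1
    indices = list(frame_indices)
    # Assumed handling of short videos: repeat the index list until the span fits.
    while len(indices) < span:
        indices += list(frame_indices)
    start = rng.randint(0, len(indices) - span)
    return indices[start:start + span:stride]


def tsn_crop(frame_indices, seq_len=8, rng=random):
    """Divide the video into `seq_len` equal segments and pick one frame from each."""
    n = len(frame_indices)
    clip = []
    for k in range(seq_len):
        lo = k * n // seq_len
        hi = max((k + 1) * n // seq_len, lo + 1)  # at least one frame per segment
        clip.append(frame_indices[min(rng.randrange(lo, hi), n - 1)])
    return clip


video = list(range(150))
pieces = dense_split(video)
print([len(p) for p in pieces])  # [64, 64, 22]
clip = stride_crop(pieces[0], rng=random.Random(0))
print(len(clip), clip[1] - clip[0])  # 8 4
```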


================================================
FILE: configs/res50_cels_cal.yaml
================================================
MODEL:
  NAME: resnet50
LOSS:
  CLA_LOSS: crossentropylabelsmooth
  CAL: cal
TAG: res50-cels-cal
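`crossentropylabelsmooth` selects the loss in `losses/cross_entropy_loss_with_label_smooth.py`, which is not shown here. The usual formulation replaces the one-hot target for identity y with (1 - eps) * onehot(y) + eps / K over K identities and takes the cross-entropy against softmax(logits). The stdlib sketch below applies that formulation to a single sample; eps = 0.1 is an assumed value, since `LOSS.EPSILON` in the default configs is documented as belonging to the clothes-based adversarial loss.

```python
import math


def log_softmax(logits):
    # Subtract the max logit for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def ce_label_smooth(logits, target, epsilon=0.1):
    """Cross-entropy between a smoothed one-hot target and softmax(logits)."""
    k = len(logits)
    log_probs = log_softmax(logits)
    loss = 0.0
    for i, lp in enumerate(log_probs):
        # Smoothed target: most mass on the true class, eps spread uniformly.
        q = (1.0 - epsilon) * (1.0 if i == target else 0.0) + epsilon / k
        loss -= q * lp
    return loss


logits = [2.0, 0.5, -1.0]
print(ce_label_smooth(logits, 0, epsilon=0.0))  # plain cross-entropy
print(ce_label_smooth(logits, 0, epsilon=0.1))  # smoothed version
```

With eps = 0 the function reduces to ordinary cross-entropy; a positive eps penalises over-confident logits even when the prediction is correct.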

================================================
FILE: configs/res50_cels_cal_16x4.yaml
================================================
MODEL:
  NAME: resnet50
DATA:
  NUM_INSTANCES: 4
  TRAIN_BATCH: 32
LOSS:
  CLA_LOSS: crossentropylabelsmooth
  CAL: cal
TAG: res50-cels-cal-16x4
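This variant lowers `NUM_INSTANCES` to 4 while keeping `TRAIN_BATCH: 32`. With an identity-based sampler such as `DistributedRandomIdentitySampler` (in `data/samplers.py`, not shown here), each batch from a data loader typically holds `TRAIN_BATCH // NUM_INSTANCES` identities with `NUM_INSTANCES` samples each. The sketch below is a hypothetical single-process, stdlib version of that P x K sampling, not the repository's distributed sampler; as a simplification it draws one group of K samples per identity per epoch.

```python
import random
from collections import defaultdict


def pk_batches(pids, batch_size=32, num_instances=4, seed=1):
    """Yield lists of dataset indices, each holding
    batch_size // num_instances identities with num_instances samples each."""
    assert batch_size % num_instances == 0, "batch size must be a multiple of K"
    rng = random.Random(seed)
    index_by_pid = defaultdict(list)
    for idx, pid in enumerate(pids):
        index_by_pid[pid].append(idx)

    # For every identity, draw exactly num_instances indices, sampling with
    # replacement when an identity has fewer samples than that.
    chunks = []
    for idxs in index_by_pid.values():
        if len(idxs) >= num_instances:
            chunks.append(rng.sample(idxs, num_instances))
        else:
            chunks.append([rng.choice(idxs) for _ in range(num_instances)])
    rng.shuffle(chunks)

    p = batch_size // num_instances  # identities per batch
    # Drop the incomplete last batch, as drop_last=True does in the loader.
    for start in range(0, len(chunks) - p + 1, p):
        yield [i for chunk in chunks[start:start + p] for i in chunk]
```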

================================================
FILE: configs/res50_cels_cal_tri_16x4.yaml
================================================
MODEL:
  NAME: resnet50
DATA:
  NUM_INSTANCES: 4
  TRAIN_BATCH: 32
LOSS:
  CLA_LOSS: crossentropylabelsmooth
  PAIR_LOSS: triplet
  CAL: cal
  PAIR_M: 0.3
  PAIR_LOSS_WEIGHT: 1.0
TAG: res50-cels-cal-tri-16x4
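This config is the only one here that turns on the pairwise term (`PAIR_LOSS_WEIGHT: 1.0`, whereas the defaults use 0.0). `losses/triplet_loss.py` is not shown, so the sketch assumes the common batch-hard formulation: for each anchor, take the farthest same-identity sample and the closest different-identity sample in the batch, then apply a hinge with margin `PAIR_M = 0.3`. Euclidean distance on plain lists is an assumption; the repository's choice of distance and feature normalisation is not visible here. `PAIR_LOSS_WEIGHT` presumably scales this term before it is added to the classification loss.

```python
import math


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def batch_hard_triplet(features, pids, margin=0.3):
    """Mean over anchors of max(0, d(a, hardest pos) - d(a, hardest neg) + margin)."""
    n = len(features)
    losses = []
    for a in range(n):
        pos = [euclidean(features[a], features[j]) for j in range(n)
               if j != a and pids[j] == pids[a]]
        neg = [euclidean(features[a], features[j]) for j in range(n)
               if pids[j] != pids[a]]
        if not pos or not neg:
            continue  # an anchor needs at least one positive and one negative
        # Hardest positive = farthest same-identity sample,
        # hardest negative = closest different-identity sample.
        losses.append(max(0.0, max(pos) - min(neg) + margin))
    return sum(losses) / len(losses) if losses else 0.0
```

Because the P x K sampler guarantees K samples per identity in every batch, each anchor normally has both positives and negatives available.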

================================================
FILE: data/__init__.py
================================================
import data.img_transforms as T
import data.spatial_transforms as ST
import data.temporal_transforms as TT
from torch.utils.data import DataLoader
from data.dataloader import DataLoaderX
from data.dataset_loader import ImageDataset, VideoDataset
from data.samplers import DistributedRandomIdentitySampler, DistributedInferenceSampler
from data.datasets.ltcc import LTCC
from data.datasets.prcc import PRCC
from data.datasets.last import LaST
from data.datasets.ccvid import CCVID
from data.datasets.deepchange import DeepChange
from data.datasets.vcclothes import VCClothes, VCClothesSameClothes, VCClothesClothesChanging


__factory = {
    'ltcc': LTCC,
    'prcc': PRCC,
    'vcclothes': VCClothes,
    'vcclothes_sc': VCClothesSameClothes,
    'vcclothes_cc': VCClothesClothesChanging,
    'last': LaST,
    'ccvid': CCVID,
    'deepchange': DeepChange,
}

VID_DATASET = ['ccvid']


def get_names():
    return list(__factory.keys())


def build_dataset(config):
    if config.DATA.DATASET not in __factory.keys():
        raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(config.DATA.DATASET, __factory.keys()))

    if config.DATA.DATASET in VID_DATASET:
        dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT, 
                                                 sampling_step=config.DATA.SAMPLING_STEP,
                                                 seq_len=config.AUG.SEQ_LEN, 
                                                 stride=config.AUG.SAMPLING_STRIDE)
    else:
        dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT)

    return dataset


def build_img_transforms(config):
    transform_train = T.Compose([
        T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),
        T.RandomCroping(p=config.AUG.RC_PROB),
        T.RandomHorizontalFlip(p=config.AUG.RF_PROB),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.RandomErasing(probability=config.AUG.RE_PROB)
    ])
    transform_test = T.Compose([
        T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    return transform_train, transform_test


def build_vid_transforms(config):
    spatial_transform_train = ST.Compose([
        ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3),
        ST.RandomHorizontalFlip(),
        ST.ToTensor(),
        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ST.RandomErasing(height=config.DATA.HEIGHT, width=config.DATA.WIDTH, probability=config.AUG.RE_PROB)
    ])
    spatial_transform_test = ST.Compose([
        ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3),
        ST.ToTensor(),
        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    if config.AUG.TEMPORAL_SAMPLING_MODE == 'tsn':
        temporal_transform_train = TT.TemporalDivisionCrop(size=config.AUG.SEQ_LEN)
    elif config.AUG.TEMPORAL_SAMPLING_MODE == 'stride':
        temporal_transform_train = TT.TemporalRandomCrop(size=config.AUG.SEQ_LEN, 
                                                         stride=config.AUG.SAMPLING_STRIDE)
    else:
        raise KeyError("Invalid temporal sampling mode '{}'".format(config.AUG.TEMPORAL_SAMPLING_MODE))

    temporal_transform_test = None

    return spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test


def build_dataloader(config):
    dataset = build_dataset(config)
    # video dataset
    if config.DATA.DATASET in VID_DATASET:
        spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test = build_vid_transforms(config)

        if config.DATA.DENSE_SAMPLING:
            train_sampler = DistributedRandomIdentitySampler(dataset.train_dense, 
                                                             num_instances=config.DATA.NUM_INSTANCES, 
                                                             seed=config.SEED)
            # split each original training video into a series of short videos and sample one clip for each short video during training
            trainloader = DataLoaderX(
                dataset=VideoDataset(dataset.train_dense, spatial_transform_train, temporal_transform_train),
                sampler=train_sampler,
                batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,
                pin_memory=True, drop_last=True)
        else:
            train_sampler = DistributedRandomIdentitySampler(dataset.train, 
                                                             num_instances=config.DATA.NUM_INSTANCES, 
                                                             seed=config.SEED)
            # sample one clip for each original training video during training
            trainloader = DataLoaderX(
                dataset=VideoDataset(dataset.train, spatial_transform_train, temporal_transform_train),
                sampler=train_sampler,
                batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,
                pin_memory=True, drop_last=True)
        
        # split each original test video into a series of clips and use the averaged feature of all clips as its representation
        queryloader = DataLoaderX(
            dataset=VideoDataset(dataset.recombined_query, spatial_transform_test, temporal_transform_test),
            sampler=DistributedInferenceSampler(dataset.recombined_query),
            batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
            pin_memory=True, drop_last=False, shuffle=False)
        galleryloader = DataLoaderX(
            dataset=VideoDataset(dataset.recombined_gallery, spatial_transform_test, temporal_transform_test),
            sampler=DistributedInferenceSampler(dataset.recombined_gallery),
            batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
            pin_memory=True, drop_last=False, shuffle=False)

        return trainloader, queryloader, galleryloader, dataset, train_sampler
    # image dataset
    else:
        transform_train, transform_test = build_img_transforms(config)
        train_sampler = DistributedRandomIdentitySampler(dataset.train, 
                                                         num_instances=config.DATA.NUM_INSTANCES, 
                                                         seed=config.SEED)
        trainloader = DataLoaderX(dataset=ImageDataset(dataset.train, transform=transform_train),
                                 sampler=train_sampler,
                                 batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,
                                 pin_memory=True, drop_last=True)

        galleryloader = DataLoaderX(dataset=ImageDataset(dataset.gallery, transform=transform_test),
                                   sampler=DistributedInferenceSampler(dataset.gallery),
                                   batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
                                   pin_memory=True, drop_last=False, shuffle=False)

        if config.DATA.DATASET == 'prcc':
            queryloader_same = DataLoaderX(dataset=ImageDataset(dataset.query_same, transform=transform_test),
                                     sampler=DistributedInferenceSampler(dataset.query_same),
                                     batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
                                     pin_memory=True, drop_last=False, shuffle=False)
            queryloader_diff = DataLoaderX(dataset=ImageDataset(dataset.query_diff, transform=transform_test),
                                     sampler=DistributedInferenceSampler(dataset.query_diff),
                                     batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
                                     pin_memory=True, drop_last=False, shuffle=False)

            return trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler
        else:
            queryloader = DataLoaderX(dataset=ImageDataset(dataset.query, transform=transform_test),
                                     sampler=DistributedInferenceSampler(dataset.query),
                                     batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,
                                     pin_memory=True, drop_last=False, shuffle=False)

            return trainloader, queryloader, galleryloader, dataset, train_sampler
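

# Illustrative sketch (not used by the training pipeline): identity-balanced ("P x K") ordering,
# the usual purpose of a sampler configured with num_instances. Every identity contributes runs
# of `num_instances` consecutive indices, so a batch of TRAIN_BATCH samples holds up to
# TRAIN_BATCH // num_instances identities. This single-process version and its name are
# hypothetical; it is not the repository's DistributedRandomIdentitySampler (data/samplers.py),
# which presumably also partitions the indices across distributed ranks.
import random
from collections import defaultdict


def identity_balanced_order(pids, num_instances, seed=0):
    rng = random.Random(seed)
    index_by_pid = defaultdict(list)
    for idx, pid in enumerate(pids):
        index_by_pid[pid].append(idx)

    chunks = []
    for idxs in index_by_pid.values():
        idxs = list(idxs)
        if len(idxs) < num_instances:
            # rare identities are topped up by sampling with replacement
            idxs += rng.choices(idxs, k=num_instances - len(idxs))
        rng.shuffle(idxs)
        # keep only complete runs of num_instances indices; leftovers are dropped for this epoch
        for start in range(0, len(idxs) - num_instances + 1, num_instances):
            chunks.append(idxs[start:start + num_instances])

    rng.shuffle(chunks)  # randomize the order in which identities appear
    return [idx for chunk in chunks for idx in chunk]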

    

    


================================================
FILE: data/dataloader.py
================================================
# refer to: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/data/data_utils.py

import torch
import threading
import queue
from torch.utils.data import DataLoader
from torch import distributed as dist


"""
#based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch
This is a single-function package that transforms an arbitrary generator into a background-thread generator that
prefetches several batches of data in a parallel background thread.

This is useful if you have a computationally heavy process (CPU or GPU) that 
iteratively processes minibatches from the generator while the generator 
consumes some other resource (disk IO / loading from database / more CPU if you have unused cores). 

By default these two processes will constantly wait for one another to finish. If you make the generator work in
prefetch mode (see the usage example in the BackgroundGenerator docstring), they will work in parallel, potentially saving you GPU time.
We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.

Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
This package contains this object
 - BackgroundGenerator(any_other_generator[,max_prefetch = something])
"""


class BackgroundGenerator(threading.Thread):
    """
    Usage:
    >> for batch in BackgroundGenerator(my_minibatch_iterator):
    >>    doit()
    More details are written in the BackgroundGenerator doc
    >> help(BackgroundGenerator)
    """

    def __init__(self, generator, local_rank, max_prefetch=10):
        """
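        This function transforms a generator into a background-thread generator.
        :param generator: generator or genexp or any iterable
        It can be used with any minibatch generator.
        :param local_rank: CUDA device index that the background thread selects with torch.cuda.set_device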

        It is quite lightweight, but not entirely weightless.
        Using global variables inside the generator is not recommended (it may cause GIL contention and zero out the
        benefit of having a background thread).
        The ideal use case is when everything it requires is stored inside it and everything it
        outputs is passed through the queue.

        There's no restriction on doing weird stuff, reading/writing files, retrieving
        URLs [or whatever] whilst iterating.

        :param max_prefetch: defines how many iterations (at most) the background generator can keep
        stored at any moment of time.
        Whenever max_prefetch batches are already stored in the queue, the background thread will halt until
        one of these batches is dequeued.

        !A small max_prefetch (this class defaults to 10) is okay unless you deal with some weird file IO in your generator!

        Setting max_prefetch to -1 lets it store as many batches as it can, which will work
        slightly (if any) faster, but will require storing
        all batches in memory. If you use an infinite generator with max_prefetch=-1, it will exceed the RAM size
        unless batches are dequeued quickly enough.
        """
        super().__init__()
        self.queue = queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.exit_event = threading.Event()
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            if self.exit_event.is_set():
                break
            self.queue.put(item)
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    # Python 3 compatibility
    def __next__(self):
        return self.next()

    def __iter__(self):
        return self

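# Illustrative sketch (not used by the training pipeline): the producer/consumer idea behind
# BackgroundGenerator without the torch.cuda.set_device call, so it also runs on CPU-only
# machines. One deliberate difference, an assumption of this sketch: iteration ends on a private
# sentinel object instead of None, so a source that legitimately yields None is not cut short.
# The class name SimplePrefetcher is hypothetical.
import queue
import threading


class SimplePrefetcher(threading.Thread):
    _END = object()  # unique end-of-stream marker

    def __init__(self, iterable, max_prefetch=2):
        super().__init__(daemon=True)
        self.queue = queue.Queue(max_prefetch)
        self.iterable = iterable
        self.start()

    def run(self):
        # producer: put() blocks whenever max_prefetch items are already waiting
        for item in self.iterable:
            self.queue.put(item)
        self.queue.put(self._END)

    def __iter__(self):
        return self

    def __next__(self):
        # consumer: get() blocks until the background thread has produced the next item
        item = self.queue.get()
        if item is self._END:
            raise StopIteration
        return item
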

class DataLoaderX(DataLoader):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        local_rank = dist.get_rank()  # global rank, used as the GPU index; this assumes a single node
        self.stream = torch.cuda.Stream(local_rank)  # create a new cuda stream in each process
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super().__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def _shutdown_background_thread(self):
        if not self.iter.is_alive():
            # avoid re-entrance or ill-conditioned thread state
            return

        # Set exit event to True for background threading stopping
        self.iter.exit_event.set()

        # Exhaust all remaining elements, so that the queue becomes empty,
        # and the thread should quit
        for _ in self.iter:
            pass

        # Waiting for background thread to quit
        self.iter.join()

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            # if isinstance(self.batch[0], torch.Tensor):
            #     self.batch[0] = self.batch[0].to(device=self.local_rank, non_blocking=True)
            for k, v in enumerate(self.batch):
                if isinstance(self.batch[k], torch.Tensor):
                    self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)

    def __next__(self):
        torch.cuda.current_stream().wait_stream(
            self.stream
        )  # wait until the preloaded tensors have been copied to the GPU
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch

    # Signal for shutting down background thread
    def shutdown(self):
        # If the dataloader is to be freed, shutdown its BackgroundGenerator
        self._shutdown_background_thread()

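# Illustrative sketch (not used by the training pipeline): the one-batch lookahead that
# DataLoaderX implements with preload() and __next__(), minus the CUDA stream. `prepare` is a
# hypothetical stand-in for the non-blocking .to(device) copy; the point is that batch k+1 is
# already being prepared when batch k is returned. Like DataLoaderX, a None batch ends iteration.
# The class name OneBatchLookahead is hypothetical.
class OneBatchLookahead:
    def __init__(self, iterable, prepare=lambda batch: batch):
        self.it = iter(iterable)
        self.prepare = prepare
        self._preload()

    def _preload(self):
        batch = next(self.it, None)
        self.batch = None if batch is None else self.prepare(batch)

    def __iter__(self):
        return self

    def __next__(self):
        batch = self.batch
        if batch is None:
            raise StopIteration
        self._preload()  # start on the following batch before handing this one out
        return batch
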

================================================
FILE: data/dataset_loader.py
================================================
import torch
import functools
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset


def read_image(img_path):
    """Keep reading image until succeed.
    This can avoid IOError incurred by heavy IO process."""
    got_img = False
    if not osp.exists(img_path):
        raise IOError("{} does not exist".format(img_path))
    while not got_img:
        try:
            img = Image.open(img_path).convert('RGB')
            got_img = True
        except IOError:
            print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
            pass
    return img


class ImageDataset(Dataset):
    """Image Person ReID Dataset"""
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid, clothes_id = self.dataset[index]
        img = read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, pid, camid, clothes_id


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def accimage_loader(path):
    try:
        import accimage
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def get_default_image_loader():
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader
    else:
        return pil_loader


def image_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


def video_loader(img_paths, image_loader):
    video = []
    for image_path in img_paths:
        if osp.exists(image_path):
            video.append(image_loader(image_path))
        else:
            return video

    return video


def get_default_video_loader():
    image_loader = get_default_image_loader()
    return functools.partial(video_loader, image_loader=image_loader)
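

# Illustrative sketch (not used by VideoDataset): video_loader stops at the first missing frame
# instead of skipping it, so a clip can come back shorter than the list of paths it was given.
# load_existing_prefix is a hypothetical standalone copy of that control flow, and the demo binds
# a stand-in image loader with functools.partial, as get_default_video_loader does.
import functools
import os.path as osp
import tempfile


def load_existing_prefix(img_paths, image_loader):
    # same control flow as video_loader: return what was loaded so far at the first missing path
    video = []
    for image_path in img_paths:
        if not osp.exists(image_path):
            return video
        video.append(image_loader(image_path))
    return video


def demo_missing_frame():
    with tempfile.TemporaryDirectory() as root:
        paths = [osp.join(root, '{:03d}.jpg'.format(i)) for i in range(4)]
        for p in paths[:2] + paths[3:]:  # frame 002 is deliberately missing
            open(p, 'wb').close()
        # osp.basename stands in for a real image loader such as pil_loader
        loader = functools.partial(load_existing_prefix, image_loader=osp.basename)
        return loader(paths)  # frame 003 exists but is never reached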


class VideoDataset(Dataset):
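    """Video Person ReID Dataset.
    Note:
        Batch data has shape N x C x T x H x W
    Args:
        dataset (list): List with items (img_paths, pid, camid, clothes_id), or (img_paths, pid, camid)
            when cloth_changing is False.
        spatial_transform (callable, optional): A transform applied to every frame of a clip; it must
            provide randomize_parameters() so that all frames of a clip share the same random parameters.
        temporal_transform (callable, optional): A function/transform that takes in a list of frame paths
            and returns a transformed version.
        get_loader (callable, optional): A factory returning a function that loads a video given its frame paths.
        cloth_changing (bool): Whether the dataset items carry a clothes_id.
    """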

    def __init__(self, 
                 dataset, 
                 spatial_transform=None,
                 temporal_transform=None,
                 get_loader=get_default_video_loader,
                 cloth_changing=True):
        self.dataset = dataset
        self.spatial_transform = spatial_transform
        self.temporal_transform = temporal_transform
        self.loader = get_loader()
        self.cloth_changing = cloth_changing

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (clip, pid, camid, clothes_id) if cloth_changing, else (clip, pid, camid), where pid is the identity of the clip.
        """
        if self.cloth_changing:
            img_paths, pid, camid, clothes_id = self.dataset[index]
        else:
            img_paths, pid, camid = self.dataset[index]

        if self.temporal_transform is not None:
            img_paths = self.temporal_transform(img_paths)

        clip = self.loader(img_paths)

        if self.spatial_transform is not None:
            self.spatial_transform.randomize_parameters()
            clip = [self.spatial_transform(img) for img in clip]

        # permute T x C x H x W to C x T x H x W
        clip = torch.stack(clip, 0).permute(1, 0, 2, 3)

        if self.cloth_changing:
            return clip, pid, camid, clothes_id
        else:
            return clip, pid, camid

================================================
FILE: data/datasets/ccvid.py
================================================
import os
import re
import glob
import h5py
import random
import math
import logging
import numpy as np
import os.path as osp
from scipy.io import loadmat
from tools.utils import mkdir_if_missing, write_json, read_json


class CCVID(object):
    """ CCVID

    Reference:
        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022.
    """
    def __init__(self, root='/data/datasets/', sampling_step=64, seq_len=16, stride=4, **kwargs):
        self.root = osp.join(root, 'CCVID')
        self.train_path = osp.join(self.root, 'train.txt')
        self.query_path = osp.join(self.root, 'query.txt')
        self.gallery_path = osp.join(self.root, 'gallery.txt')
        self._check_before_run()
 
        train, num_train_tracklets, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes, _ = \
            self._process_data(self.train_path, relabel=True)
        clothes2label = self._clothes2label_test(self.query_path, self.gallery_path)
        query, num_query_tracklets, num_query_pids, num_query_imgs, num_query_clothes, _, _ = \
            self._process_data(self.query_path, relabel=False, clothes2label=clothes2label)
        gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs, num_gallery_clothes, _, _ = \
            self._process_data(self.gallery_path, relabel=False, clothes2label=clothes2label)

        # slice each full-length video in the training set into shorter video clips
        train_dense = self._densesampling_for_trainingset(train, sampling_step)
        # In the test stage, each video sample is divided into a series of equal-length video clips with a pre-defined stride.
        recombined_query, query_vid2clip_index = self._recombination_for_testset(query, seq_len=seq_len, stride=stride)
        recombined_gallery, gallery_vid2clip_index = self._recombination_for_testset(gallery, seq_len=seq_len, stride=stride)
       
        num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 
        min_num = np.min(num_imgs_per_tracklet)
        max_num = np.max(num_imgs_per_tracklet)
        avg_num = np.mean(num_imgs_per_tracklet)

        num_total_pids = num_train_pids + num_gallery_pids
        num_total_clothes = num_train_clothes + len(clothes2label)
        num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 

        logger = logging.getLogger('reid.dataset')
        logger.info("=> CCVID loaded")
        logger.info("Dataset statistics:")
        logger.info("  ---------------------------------------------")
        logger.info("  subset       | # ids | # tracklets | # clothes")
        logger.info("  ---------------------------------------------")
        logger.info("  train        | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_clothes))
        logger.info("  train_dense  | {:5d} | {:11d} | {:9d}".format(num_train_pids, len(train_dense), num_train_clothes))
        logger.info("  query        | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_clothes))
        logger.info("  gallery      | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_clothes))
        logger.info("  ---------------------------------------------")
        logger.info("  total        | {:5d} | {:11d} | {:9d}".format(num_total_pids, num_total_tracklets, num_total_clothes))
        logger.info("  number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num))
        logger.info("  ---------------------------------------------")

        self.train = train
        self.train_dense = train_dense
        self.query = query
        self.gallery = gallery

        self.recombined_query = recombined_query
        self.recombined_gallery = recombined_gallery
        self.query_vid2clip_index = query_vid2clip_index
        self.gallery_vid2clip_index = gallery_vid2clip_index

        self.num_train_pids = num_train_pids
        self.num_train_clothes = num_train_clothes
        self.pid2clothes = pid2clothes

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.root):
            raise RuntimeError("'{}' is not available".format(self.root))
        if not osp.exists(self.train_path):
            raise RuntimeError("'{}' is not available".format(self.train_path))
        if not osp.exists(self.query_path):
            raise RuntimeError("'{}' is not available".format(self.query_path))
        if not osp.exists(self.gallery_path):
            raise RuntimeError("'{}' is not available".format(self.gallery_path))

    def _clothes2label_test(self, query_path, gallery_path):
        pid_container = set()
        clothes_container = set()
        with open(query_path, 'r') as f:
            for line in f:
                new_line = line.rstrip()
                tracklet_path, pid, clothes_label = new_line.split()
                clothes = '{}_{}'.format(pid, clothes_label)
                pid_container.add(pid)
                clothes_container.add(clothes)
        with open(gallery_path, 'r') as f:
            for line in f:
                new_line = line.rstrip()
                tracklet_path, pid, clothes_label = new_line.split()
                clothes = '{}_{}'.format(pid, clothes_label)
                pid_container.add(pid)
                clothes_container.add(clothes)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}

        return clothes2label

    def _process_data(self, data_path, relabel=False, clothes2label=None):
        tracklet_path_list = []
        pid_container = set()
        clothes_container = set()
        with open(data_path, 'r') as f:
            for line in f:
                new_line = line.rstrip()
                tracklet_path, pid, clothes_label = new_line.split()
                tracklet_path_list.append((tracklet_path, pid, clothes_label))
                clothes = '{}_{}'.format(pid, clothes_label)
                pid_container.add(pid)
                clothes_container.add(clothes)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        if clothes2label is None:
            clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}

        num_tracklets = len(tracklet_path_list)
        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        tracklets = []
        num_imgs_per_tracklet = []
        pid2clothes = np.zeros((num_pids, len(clothes2label)))

        for tracklet_path, pid, clothes_label in tracklet_path_list:
            img_paths = glob.glob(osp.join(self.root, tracklet_path, '*')) 
            img_paths.sort()

            clothes = '{}_{}'.format(pid, clothes_label)
            clothes_id = clothes2label[clothes]
            pid2clothes[pid2label[pid], clothes_id] = 1
            if relabel:
                pid = pid2label[pid]
            else:
                pid = int(pid)
            session = tracklet_path.split('/')[0]
            cam = tracklet_path.split('_')[1]
            if session == 'session3':
                camid = int(cam) + 12
            else:
                camid = int(cam)

            num_imgs_per_tracklet.append(len(img_paths))
            tracklets.append((img_paths, pid, camid, clothes_id))

        num_tracklets = len(tracklets)

        return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet, num_clothes, pid2clothes, clothes2label

    def _densesampling_for_trainingset(self, dataset, sampling_step=64):
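        ''' Split every video in the training set into shorter clips for dense sampling.
        Each clip has sampling_step frames, except the last one, which also keeps the leftover frames;
        videos shorter than sampling_step (or all videos, if sampling_step is 0) are kept whole.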

        Args:
            dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id)
            sampling_step (int): sampling step for dense sampling

        Returns:
            new_dataset (list): output dataset
        '''
        new_dataset = []
        for (img_paths, pid, camid, clothes_id) in dataset:
            if sampling_step != 0:
                num_sampling = len(img_paths)//sampling_step
                if num_sampling == 0:
                    new_dataset.append((img_paths, pid, camid, clothes_id))
                else:
                    for idx in range(num_sampling):
                        if idx == num_sampling - 1:
                            new_dataset.append((img_paths[idx*sampling_step:], pid, camid, clothes_id))
                        else:
                            new_dataset.append((img_paths[idx*sampling_step : (idx+1)*sampling_step], pid, camid, clothes_id))
            else:
                new_dataset.append((img_paths, pid, camid, clothes_id))

        return new_dataset

    def _recombination_for_testset(self, dataset, seq_len=16, stride=4):
        ''' Split every video in the test set into equal-length clips.

        Args:
            dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id)
            seq_len (int): sequence length of each output clip
            stride (int): temporal sampling stride

        Returns:
            new_dataset (list): output dataset with lots of equilong clips
            vid2clip_index (list): a list containing the start and end (exclusive) clip index of each original video
        '''
        new_dataset = []
        vid2clip_index = np.zeros((len(dataset), 2), dtype=int)
        for idx, (img_paths, pid, camid, clothes_id) in enumerate(dataset):
            # start index
            vid2clip_index[idx, 0] = len(new_dataset)
            # process the part of the sequence that is divisible by seq_len*stride
            for i in range(len(img_paths)//(seq_len*stride)):
                for j in range(stride):
                    begin_idx = i * (seq_len * stride) + j
                    end_idx = (i + 1) * (seq_len * stride)
                    clip_paths = img_paths[begin_idx : end_idx : stride]
                    assert(len(clip_paths) == seq_len)
                    new_dataset.append((clip_paths, pid, camid, clothes_id))
            # process the remaining frames that are not divisible by seq_len*stride
            if len(img_paths)%(seq_len*stride) != 0:
                # reduce the stride so the remaining frames still form clips of seq_len frames
                new_stride = (len(img_paths)%(seq_len*stride)) // seq_len
                for i in range(new_stride):
                    begin_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + i
                    end_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + seq_len * new_stride
                    clip_paths = img_paths[begin_idx : end_idx : new_stride]
                    assert(len(clip_paths) == seq_len)
                    new_dataset.append((clip_paths, pid, camid, clothes_id))
                # process the remaining frames that are not divisible by seq_len
                if len(img_paths) % seq_len != 0:
                    clip_paths = img_paths[len(img_paths)//seq_len*seq_len:]
                    # loop padding
                    while len(clip_paths) < seq_len:
                        for index in clip_paths:
                            if len(clip_paths) >= seq_len:
                                break
                            clip_paths.append(index)
                    assert(len(clip_paths) == seq_len)
                    new_dataset.append((clip_paths, pid, camid, clothes_id))
            # end index
            vid2clip_index[idx, 1] = len(new_dataset)
            assert((vid2clip_index[idx, 1]-vid2clip_index[idx, 0]) == math.ceil(len(img_paths)/seq_len))

        return new_dataset, vid2clip_index.tolist()
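

# Illustrative sketch (not used by CCVID itself): how _process_data and _clothes2label_test turn
# (pid, clothes_label) pairs into label maps, shown for a single record set. Outfit keys are
# '<pid>_<clothes_label>', so the same clothes label worn by two identities counts as two
# different outfits, and pid2clothes marks which outfits each identity wears. Plain lists replace
# the numpy array here; the function name is hypothetical.
def build_clothes_maps(records):
    records = list(records)
    pids = sorted({pid for pid, _ in records})
    outfits = sorted({'{}_{}'.format(pid, c) for pid, c in records})
    pid2label = {pid: label for label, pid in enumerate(pids)}
    clothes2label = {name: label for label, name in enumerate(outfits)}
    # pid2clothes[i][j] == 1 iff identity i is observed in outfit j
    pid2clothes = [[0] * len(outfits) for _ in pids]
    for pid, c in records:
        pid2clothes[pid2label[pid]][clothes2label['{}_{}'.format(pid, c)]] = 1
    return pid2label, clothes2label, pid2clothes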

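
# Illustrative sketch (not used by CCVID itself): the (start, end) frame ranges that
# _densesampling_for_trainingset produces for one video of num_frames frames. With the default
# sampling_step of 64, a 150-frame video becomes one 64-frame clip plus one 86-frame clip,
# because the last clip keeps the leftover frames. The function name is hypothetical.
def dense_sampling_ranges(num_frames, sampling_step=64):
    if sampling_step == 0 or num_frames < sampling_step:
        return [(0, num_frames)]  # dense sampling disabled or video too short: keep it whole
    num_sampling = num_frames // sampling_step
    ranges = [(i * sampling_step, (i + 1) * sampling_step) for i in range(num_sampling - 1)]
    ranges.append(((num_sampling - 1) * sampling_step, num_frames))  # last clip keeps the remainder
    return ranges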

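
# Illustrative sketch (not used by CCVID itself): a frame-index version of
# _recombination_for_testset, useful for checking its invariant that every test video yields
# exactly ceil(num_frames / seq_len) clips of seq_len frames. Full blocks of seq_len*stride frames
# give `stride` interleaved clips, the remainder is covered with a reduced stride, and a final
# loop-padded clip takes the frames that do not fill a whole clip. The function name is hypothetical.
import math


def recombine_frame_indices(num_frames, seq_len=16, stride=4):
    frames = list(range(num_frames))
    block = seq_len * stride
    full = num_frames // block * block
    clips = []
    for i in range(num_frames // block):
        for j in range(stride):
            clips.append(frames[i * block + j:(i + 1) * block:stride])
    rest = num_frames % block
    if rest:
        new_stride = rest // seq_len
        for i in range(new_stride):
            clips.append(frames[full + i:full + seq_len * new_stride:new_stride])
        if num_frames % seq_len:
            tail = frames[num_frames // seq_len * seq_len:]
            clips.append([tail[k % len(tail)] for k in range(seq_len)])  # loop padding
    assert len(clips) == math.ceil(num_frames / seq_len)
    return clips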

================================================
FILE: data/datasets/deepchange.py
================================================
import os
import re
import glob
import h5py
import random
import math
import logging
import numpy as np
import os.path as osp
from scipy.io import loadmat
from tools.utils import mkdir_if_missing, write_json, read_json
     

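# Illustrative sketch (not used by DeepChange itself): the file-name convention that
# DeepChange.get_pid2label_and_clothes2label relies on. The stem is split on '_'; the identity is
# the first field without its one-character prefix, and the outfit key joins the first and third
# fields. The function name and the example file name 'P0123_C05_D02_F000123.jpg' are hypothetical.
def parse_deepchange_name(img_name):
    names = img_name.split('.')[0].split('_')
    clothes = names[0] + names[2]  # identity field + third field form the outfit key
    pid = int(names[0][1:])        # drop the leading prefix character, keep the number
    return pid, clothes

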
class DeepChange(object):
    """ DeepChange

    Reference:
        Xu et al. DeepChange: A Long-Term Person Re-Identification Benchmark. arXiv:2105.14685, 2021.

    URL: https://github.com/PengBoXiangShang/deepchange
    """
    dataset_dir = 'DeepChangeDataset'
    def __init__(self, root='data', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'train-set')
        self.train_list = osp.join(self.dataset_dir, 'train-set-bbox.txt')
        self.val_query_dir = osp.join(self.dataset_dir, 'val-set-query')
        self.val_query_list = osp.join(self.dataset_dir, 'val-set-query-bbox.txt')
        self.val_gallery_dir = osp.join(self.dataset_dir, 'val-set-gallery')
        self.val_gallery_list = osp.join(self.dataset_dir, 'val-set-gallery-bbox.txt')
        self.test_query_dir = osp.join(self.dataset_dir, 'test-set-query')
        self.test_query_list = osp.join(self.dataset_dir, 'test-set-query-bbox.txt')
        self.test_gallery_dir = osp.join(self.dataset_dir, 'test-set-gallery')
        self.test_gallery_list = osp.join(self.dataset_dir, 'test-set-gallery-bbox.txt')
        self._check_before_run()

        train_names = self._get_names(self.train_list)
        val_query_names = self._get_names(self.val_query_list)
        val_gallery_names = self._get_names(self.val_gallery_list)
        test_query_names = self._get_names(self.test_query_list)
        test_gallery_names = self._get_names(self.test_gallery_list)

        pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(train_names)
        train, num_train_pids, num_train_clothes = self._process_dir(self.train_dir, train_names, clothes2label, pid2label=pid2label)

        pid2label, clothes2label = self.get_pid2label_and_clothes2label(val_query_names, val_gallery_names)
        val_query, num_val_query_pids, num_val_query_clothes  = self._process_dir(self.val_query_dir, val_query_names, clothes2label)
        val_gallery, num_val_gallery_pids, num_val_gallery_clothes = self._process_dir(self.val_gallery_dir, val_gallery_names, clothes2label)
        num_val_pids = len(pid2label)
        num_val_clothes = len(clothes2label)

        pid2label, clothes2label = self.get_pid2label_and_clothes2label(test_query_names, test_gallery_names)
        test_query, num_test_query_pids, num_test_query_clothes = self._process_dir(self.test_query_dir, test_query_names, clothes2label)
        test_gallery, num_test_gallery_pids, num_test_gallery_clothes = self._process_dir(self.test_gallery_dir, test_gallery_names, clothes2label)
        num_test_pids = len(pid2label)
        num_test_clothes = len(clothes2label)

        num_total_pids = num_train_pids + num_val_pids + num_test_pids
        num_total_clothes = num_train_clothes + num_val_clothes + num_test_clothes
        num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + len(test_gallery)

        logger = logging.getLogger('reid.dataset')
        logger.info("=> DeepChange loaded")
        logger.info("Dataset statistics:")
        logger.info("  --------------------------------------------")
        logger.info("  subset        | # ids | # images | # clothes")
        logger.info("  --------------------------------------------")
        logger.info("  train         | {:5d} | {:8d} | {:9d} ".format(num_train_pids, len(train), num_train_clothes))
        logger.info("  query(val)    | {:5d} | {:8d} | {:9d} ".format(num_val_query_pids, len(val_query), num_val_query_clothes))
        logger.info("  gallery(val)  | {:5d} | {:8d} | {:9d} ".format(num_val_gallery_pids, len(val_gallery), num_val_gallery_clothes))
        logger.info("  query         | {:5d} | {:8d} | {:9d} ".format(num_test_query_pids, len(test_query), num_test_query_clothes))
        logger.info("  gallery       | {:5d} | {:8d} | {:9d} ".format(num_test_gallery_pids, len(test_gallery), num_test_gallery_clothes))
        logger.info("  --------------------------------------------")
        logger.info("  total         | {:5d} | {:8d} | {:9d} ".format(num_total_pids, num_total_imgs, num_total_clothes))
        logger.info("  --------------------------------------------")

        self.train = train
        self.val_query = val_query
        self.val_gallery = val_gallery
        self.query = test_query
        self.gallery = test_gallery

        self.num_train_pids = num_train_pids
        self.num_train_clothes = num_train_clothes
        self.pid2clothes = pid2clothes

    def _get_names(self, fpath):
        names = []
        with open(fpath, 'r') as f:
            for line in f:
                new_line = line.rstrip()
                names.append(new_line)
        return names

    def get_pid2label_and_clothes2label(self, img_names1, img_names2=None):
        if img_names2 is not None:
            img_names = img_names1 + img_names2
        else:
            img_names = img_names1

        pid_container = set()
        clothes_container = set()
        for img_name in img_names:
            names = img_name.split('.')[0].split('_')
            clothes = names[0] + names[2]
            pid = int(names[0][1:])
            pid_container.add(pid)
            clothes_container.add(clothes)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}

        if img_names2 is not None:
            return pid2label, clothes2label

        num_pids = len(pid_container)
        num_clothes = len(clothes_container)
        pid2clothes = np.zeros((num_pids, num_clothes))
        for img_name in img_names:
            names = img_name.split('.')[0].split('_')
            clothes = names[0] + names[2]
            pid = int(names[0][1:])
            pid = pid2label[pid]
            clothes_id = clothes2label[clothes]
            pid2clothes[pid, clothes_id] = 1

        return pid2label, clothes2label, pid2clothes

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.val_query_dir):
            raise RuntimeError("'{}' is not available".format(self.val_query_dir))
        if not osp.exists(self.val_gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.val_gallery_dir))
        if not osp.exists(self.test_query_dir):
            raise RuntimeError("'{}' is not available".format(self.test_query_dir))
        if not osp.exists(self.test_gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.test_gallery_dir))

    def _process_dir(self, home_dir, img_names, clothes2label, pid2label=None):
        dataset = []
        pid_container = set()
        clothes_container = set()
        for img_name in img_names:
            img_path = osp.join(home_dir, img_name.split(',')[0])
            names = img_name.split('.')[0].split('_')
            tracklet_id = int(img_name.split(',')[1])
            clothes = names[0] + names[2]
            clothes_id = clothes2label[clothes]
            clothes_container.add(clothes_id)
            pid = int(names[0][1:])
            pid_container.add(pid)
            camid = int(names[1][1:])
            if pid2label is not None:
                pid = pid2label[pid]
            # On DeepChange, following the original paper, a gallery image taken by the
            # same camera as the query still counts as a true match if it belongs to a
            # different tracklet. We therefore store tracklet_id in place of camid.
            dataset.append((img_path, pid, tracklet_id, clothes_id))
        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        return dataset, num_pids, num_clothes
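

# ------------------------------------------------------------------------
# Illustrative sketch (not part of the original Simple-CCReID code).
# It shows why storing tracklet_id in the camid slot (see _process_dir above)
# keeps same-camera matches. It assumes the evaluation follows the usual
# re-id protocol of discarding gallery samples that share BOTH the pid and
# the camid of the query. The helper name below is hypothetical.
# ------------------------------------------------------------------------
import numpy as np


def sketch_valid_gallery_mask(q_pid, q_camid, g_pids, g_camids):
    """Return a boolean mask of the gallery samples kept for one query."""
    g_pids = np.asarray(g_pids)
    g_camids = np.asarray(g_camids)
    # A gallery sample is dropped only if it has the same pid AND the same camid.
    junk = (g_pids == q_pid) & (g_camids == q_camid)
    return ~junk


# Usage: query of person 5 from camera 1 (tracklet 7); the gallery holds person 5
# from camera 1 (tracklet 8) and from camera 2 (tracklet 9).
# Real camera ids drop the same-camera sighting:
#   sketch_valid_gallery_mask(5, 1, [5, 5], [1, 2]).tolist() -> [False, True]
# Tracklet ids in the camid slot keep it as a valid true match:
#   sketch_valid_gallery_mask(5, 7, [5, 5], [8, 9]).tolist() -> [True, True]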

================================================
FILE: data/datasets/last.py
================================================
import os
import re
import glob
import h5py
import random
import math
import logging
import numpy as np
import os.path as osp
from scipy.io import loadmat
from tools.utils import mkdir_if_missing, write_json, read_json


class LaST(object):
    """ LaST

    Reference:
        Shu et al. Large-Scale Spatio-Temporal Person Re-identification: Algorithm and Benchmark. arXiv:2105.15076, 2021.

    URL: https://github.com/shuxjweb/last

    Note that LaST does not provide clothes labels for the val and test sets.
    """
    dataset_dir = "last"
    def __init__(self, root='data', **kwargs):
        super(LaST, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.val_query_dir = osp.join(self.dataset_dir, 'val', 'query')
        self.val_gallery_dir = osp.join(self.dataset_dir, 'val', 'gallery')
        self.test_query_dir = osp.join(self.dataset_dir, 'test', 'query')
        self.test_gallery_dir = osp.join(self.dataset_dir, 'test', 'gallery')
        self._check_before_run()

        pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(self.train_dir)

        train, num_train_pids = self._process_dir(self.train_dir, pid2label=pid2label, clothes2label=clothes2label, relabel=True)
        val_query, num_val_query_pids = self._process_dir(self.val_query_dir, relabel=False)
        val_gallery, num_val_gallery_pids = self._process_dir(self.val_gallery_dir, relabel=False, recam=len(val_query))
        test_query, num_test_query_pids = self._process_dir(self.test_query_dir, relabel=False)
        test_gallery, num_test_gallery_pids = self._process_dir(self.test_gallery_dir, relabel=False, recam=len(test_query))

        num_total_pids = num_train_pids+num_val_gallery_pids+num_test_gallery_pids
        num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + len(test_gallery)

        logger = logging.getLogger('reid.dataset')
        logger.info("=> LaST loaded")
        logger.info("Dataset statistics:")
        logger.info("  --------------------------------------------")
        logger.info("  subset        | # ids | # images | # clothes")
        logger.info("  --------------------------------------------")
        logger.info("  train         | {:5d} | {:8d} | {:9d}".format(num_train_pids, len(train), len(clothes2label)))
        logger.info("  query(val)    | {:5d} | {:8d} |".format(num_val_query_pids, len(val_query)))
        logger.info("  gallery(val)  | {:5d} | {:8d} |".format(num_val_gallery_pids, len(val_gallery)))
        logger.info("  query         | {:5d} | {:8d} |".format(num_test_query_pids, len(test_query)))
        logger.info("  gallery       | {:5d} | {:8d} |".format(num_test_gallery_pids, len(test_gallery)))
        logger.info("  --------------------------------------------")
        logger.info("  total         | {:5d} | {:8d} | ".format(num_total_pids, num_total_imgs))
        logger.info("  --------------------------------------------")

        self.train = train
        self.val_query = val_query
        self.val_gallery = val_gallery
        self.query = test_query
        self.gallery = test_gallery

        self.num_train_pids = num_train_pids
        self.num_train_clothes = len(clothes2label)
        self.pid2clothes = pid2clothes

    def get_pid2label_and_clothes2label(self, dir_path):
        img_paths = glob.glob(osp.join(dir_path, '*/*.jpg'))            # [103367,]
        img_paths.sort()

        pid_container = set()
        clothes_container = set()
        for img_path in img_paths:
            names = osp.basename(img_path).split('.')[0].split('_')
            clothes = names[0] + '_' + names[-1]
            pid = int(names[0])
            pid_container.add(pid)
            clothes_container.add(clothes)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}

        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        pid2clothes = np.zeros((num_pids, num_clothes))
        for img_path in img_paths:
            names = osp.basename(img_path).split('.')[0].split('_')
            clothes = names[0] + '_' + names[-1]
            pid = int(names[0])
            pid = pid2label[pid]
            clothes_id = clothes2label[clothes]
            pid2clothes[pid, clothes_id] = 1

        return pid2label, clothes2label, pid2clothes

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.val_query_dir):
            raise RuntimeError("'{}' is not available".format(self.val_query_dir))
        if not osp.exists(self.val_gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.val_gallery_dir))
        if not osp.exists(self.test_query_dir):
            raise RuntimeError("'{}' is not available".format(self.test_query_dir))
        if not osp.exists(self.test_gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.test_gallery_dir))

    def _process_dir(self, dir_path, pid2label=None, clothes2label=None, relabel=False, recam=0):
        if 'query' in dir_path:
            img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        else:
            img_paths = glob.glob(osp.join(dir_path, '*/*.jpg'))
        img_paths.sort()
        
        dataset = []
        pid_container = set()
        for ii, img_path in enumerate(img_paths):
            names = osp.basename(img_path).split('.')[0].split('_')
            clothes = names[0] + '_' + names[-1]
            pid = int(names[0])
            pid_container.add(pid)
            camid = int(recam + ii)
            if relabel and pid2label is not None:
                pid = pid2label[pid]
            if relabel and clothes2label is not None:
                clothes_id = clothes2label[clothes]
            else:
                clothes_id = pid
            dataset.append((img_path, pid, camid, clothes_id))
        num_pids = len(pid_container)

        return dataset, num_pids
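

# ------------------------------------------------------------------------
# Illustrative sketch (not part of the original Simple-CCReID code).
# get_pid2label_and_clothes2label above fills the binary pid2clothes matrix
# one image at a time: entry [p, c] is 1 iff identity label p was seen in
# clothes label c. The same matrix can be built with a single NumPy
# integer-array assignment. The function name and the (pid, clothes) inputs
# in the usage note are hypothetical examples, not real LaST file names.
# ------------------------------------------------------------------------
import numpy as np


def sketch_build_pid2clothes(pids, clothes):
    """Relabel raw pids/clothes to 0..N-1; return (pid2label, clothes2label, pid2clothes)."""
    pid2label = {pid: label for label, pid in enumerate(sorted(set(pids)))}
    clothes2label = {c: label for label, c in enumerate(sorted(set(clothes)))}
    rows = [pid2label[p] for p in pids]
    cols = [clothes2label[c] for c in clothes]
    pid2clothes = np.zeros((len(pid2label), len(clothes2label)))
    # Duplicate (row, col) pairs just write 1 again, matching the loop version.
    pid2clothes[rows, cols] = 1
    return pid2label, clothes2label, pid2clothes


# Usage: person 10 wears two outfits, person 42 wears one.
#   _, _, m = sketch_build_pid2clothes([10, 10, 10, 42], ['10_a', '10_a', '10_b', '42_a'])
#   m.tolist() -> [[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]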

================================================
FILE: data/datasets/ltcc.py
================================================
import os
import re
import glob
import h5py
import random
import math
import logging
import numpy as np
import os.path as osp
from scipy.io import loadmat
from tools.utils import mkdir_if_missing, write_json, read_json


class LTCC(object):
    """ LTCC

    Reference:
        Qian et al. Long-Term Cloth-Changing Person Re-identification. arXiv:2005.12633, 2020.

    URL: https://naiq.github.io/LTCC_Perosn_ReID.html#
    """
    dataset_dir = 'LTCC_ReID'
    def __init__(self, root='data', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'test')
        self._check_before_run()

        train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \
            self._process_dir_train(self.train_dir)
        query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = \
            self._process_dir_test(self.query_dir, self.gallery_dir)
        num_total_pids = num_train_pids + num_test_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
        num_test_imgs = num_query_imgs + num_gallery_imgs 
        num_total_clothes = num_train_clothes + num_test_clothes

        logger = logging.getLogger('reid.dataset')
        logger.info("=> LTCC loaded")
        logger.info("Dataset statistics:")
        logger.info("  ----------------------------------------")
        logger.info("  subset   | # ids | # images | # clothes")
        logger.info("  ----------------------------------------")
        logger.info("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes))
        logger.info("  test     | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes))
        logger.info("  query    | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs))
        logger.info("  gallery  | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs))
        logger.info("  ----------------------------------------")
        logger.info("  total    | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes))
        logger.info("  ----------------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_train_clothes = num_train_clothes
        self.pid2clothes = pid2clothes

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir_train(self, dir_path):
        img_paths = glob.glob(osp.join(dir_path, '*.png'))
        img_paths.sort()
        pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)')
        pattern2 = re.compile(r'(\w+)_c')

        pid_container = set()
        clothes_container = set()
        for img_path in img_paths:
            pid, _, _ = map(int, pattern1.search(img_path).groups())
            clothes_id = pattern2.search(img_path).group(1)
            pid_container.add(pid)
            clothes_container.add(clothes_id)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}

        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        dataset = []
        pid2clothes = np.zeros((num_pids, num_clothes))
        for img_path in img_paths:
            pid, _, camid = map(int, pattern1.search(img_path).groups())
            clothes = pattern2.search(img_path).group(1)
            camid -= 1 # index starts from 0
            pid = pid2label[pid]
            clothes_id = clothes2label[clothes]
            dataset.append((img_path, pid, camid, clothes_id))
            pid2clothes[pid, clothes_id] = 1
        
        num_imgs = len(dataset)

        return dataset, num_pids, num_imgs, num_clothes, pid2clothes

    def _process_dir_test(self, query_path, gallery_path):
        query_img_paths = glob.glob(osp.join(query_path, '*.png'))
        gallery_img_paths = glob.glob(osp.join(gallery_path, '*.png'))
        query_img_paths.sort()
        gallery_img_paths.sort()
        pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)')
        pattern2 = re.compile(r'(\w+)_c')

        pid_container = set()
        clothes_container = set()
        for img_path in query_img_paths:
            pid, _, _ = map(int, pattern1.search(img_path).groups())
            clothes_id = pattern2.search(img_path).group(1)
            pid_container.add(pid)
            clothes_container.add(clothes_id)
        for img_path in gallery_img_paths:
            pid, _, _ = map(int, pattern1.search(img_path).groups())
            clothes_id = pattern2.search(img_path).group(1)
            pid_container.add(pid)
            clothes_container.add(clothes_id)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}

        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        query_dataset = []
        gallery_dataset = []
        for img_path in query_img_paths:
            pid, _, camid = map(int, pattern1.search(img_path).groups())
            clothes_id = pattern2.search(img_path).group(1)
            camid -= 1 # index starts from 0
            clothes_id = clothes2label[clothes_id]
            query_dataset.append((img_path, pid, camid, clothes_id))

        for img_path in gallery_img_paths:
            pid, _, camid = map(int, pattern1.search(img_path).groups())
            clothes_id = pattern2.search(img_path).group(1)
            camid -= 1 # index starts from 0
            clothes_id = clothes2label[clothes_id]
            gallery_dataset.append((img_path, pid, camid, clothes_id))
        
        num_imgs_query = len(query_dataset)
        num_imgs_gallery = len(gallery_dataset)

        return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes
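

# ------------------------------------------------------------------------
# Illustrative sketch (not part of the original Simple-CCReID code).
# The LTCC parsers above expect file names of the form
# <pid>_<clothes>_c<cam>_<...>.png: pattern1 yields (pid, clothes, cam) and
# pattern2 yields the '<pid>_<clothes>' prefix used as the clothes key.
# Both patterns are searched in the full path, so a directory name that
# contains '_c' (for example a root such as 'my_cache') would be matched
# first by pattern2. Searching only the basename, as below, avoids that.
# The helper name and the example path are hypothetical.
# ------------------------------------------------------------------------
import os.path as osp
import re

_LTCC_PATTERN1 = re.compile(r'(\d+)_(\d+)_c(\d+)')
_LTCC_PATTERN2 = re.compile(r'(\w+)_c')


def sketch_parse_ltcc_name(img_path):
    """Return (pid, zero-based camid, clothes key) parsed from an LTCC image path."""
    name = osp.basename(img_path)
    pid, _, camid = map(int, _LTCC_PATTERN1.search(name).groups())
    clothes = _LTCC_PATTERN2.search(name).group(1)
    return pid, camid - 1, clothes


# Usage (hypothetical path):
#   sketch_parse_ltcc_name('my_cache/LTCC_ReID/train/001_1_c11_000000.png') -> (1, 10, '001_1')
# whereas searching the full path with pattern2 would return 'my'.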



================================================
FILE: data/datasets/prcc.py
================================================
import os
import re
import glob
import h5py
import random
import math
import logging
import numpy as np
import os.path as osp
from scipy.io import loadmat
from tools.utils import mkdir_if_missing, write_json, read_json


class PRCC(object):
    """ PRCC

    Reference:
        Yang et al. Person Re-identification by Contour Sketch under Moderate Clothing Change. TPAMI, 2019.

    URL: https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view
    """
    dataset_dir = 'prcc'
    def __init__(self, root='data', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'rgb/train')
        self.val_dir = osp.join(self.dataset_dir, 'rgb/val')
        self.test_dir = osp.join(self.dataset_dir, 'rgb/test')
        self._check_before_run()

        train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \
            self._process_dir_train(self.train_dir)
        val, num_val_pids, num_val_imgs, num_val_clothes, _ = \
            self._process_dir_train(self.val_dir)

        query_same, query_diff, gallery, num_test_pids, \
            num_query_imgs_same, num_query_imgs_diff, num_gallery_imgs, \
            num_test_clothes, gallery_idx = self._process_dir_test(self.test_dir)

        num_total_pids = num_train_pids + num_test_pids
        num_test_imgs = num_query_imgs_same + num_query_imgs_diff + num_gallery_imgs
        num_total_imgs = num_train_imgs + num_val_imgs + num_test_imgs
        num_total_clothes = num_train_clothes + num_test_clothes

        logger = logging.getLogger('reid.dataset')
        logger.info("=> PRCC loaded")
        logger.info("Dataset statistics:")
        logger.info("  --------------------------------------------")
        logger.info("  subset      | # ids | # images | # clothes")
        logger.info("  --------------------------------------------")
        logger.info("  train       | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes))
        logger.info("  val         | {:5d} | {:8d} | {:9d}".format(num_val_pids, num_val_imgs, num_val_clothes))
        logger.info("  test        | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes))
        logger.info("  query(same) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_same))
        logger.info("  query(diff) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_diff))
        logger.info("  gallery     | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs))
        logger.info("  --------------------------------------------")
        logger.info("  total       | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes))
        logger.info("  --------------------------------------------")

        self.train = train
        self.val = val
        self.query_same = query_same
        self.query_diff = query_diff
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_train_clothes = num_train_clothes
        self.pid2clothes = pid2clothes
        self.gallery_idx = gallery_idx

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.val_dir):
            raise RuntimeError("'{}' is not available".format(self.val_dir))
        if not osp.exists(self.test_dir):
            raise RuntimeError("'{}' is not available".format(self.test_dir))

    def _process_dir_train(self, dir_path):
        pdirs = glob.glob(osp.join(dir_path, '*'))
        pdirs.sort()

        pid_container = set()
        clothes_container = set()
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
            img_dirs = glob.glob(osp.join(pdir, '*.jpg'))
            for img_dir in img_dirs:
                cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C'
                if cam in ['A', 'B']:
                    clothes_container.add(osp.basename(pdir))
                else:
                    clothes_container.add(osp.basename(pdir)+osp.basename(img_dir)[0])
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}
        cam2label = {'A': 0, 'B': 1, 'C': 2}

        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        dataset = []
        pid2clothes = np.zeros((num_pids, num_clothes))
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            img_dirs = glob.glob(osp.join(pdir, '*.jpg'))
            for img_dir in img_dirs:
                cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C'
                label = pid2label[pid]
                camid = cam2label[cam]
                if cam in ['A', 'B']:
                    clothes_id = clothes2label[osp.basename(pdir)]
                else:
                    clothes_id = clothes2label[osp.basename(pdir)+osp.basename(img_dir)[0]]
                dataset.append((img_dir, label, camid, clothes_id))
                pid2clothes[label, clothes_id] = 1            
        
        num_imgs = len(dataset)

        return dataset, num_pids, num_imgs, num_clothes, pid2clothes

    def _process_dir_test(self, test_path):
        pdirs = glob.glob(osp.join(test_path, '*'))
        pdirs.sort()

        pid_container = set()
        for pdir in glob.glob(osp.join(test_path, 'A', '*')):
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
        pid_container = sorted(pid_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        cam2label = {'A': 0, 'B': 1, 'C': 2}

        num_pids = len(pid_container)
        num_clothes = num_pids * 2

        query_dataset_same_clothes = []
        query_dataset_diff_clothes = []
        gallery_dataset = []
        for cam in ['A', 'B', 'C']:
            pdirs = glob.glob(osp.join(test_path, cam, '*'))
            for pdir in pdirs:
                pid = int(osp.basename(pdir))
                img_dirs = glob.glob(osp.join(pdir, '*.jpg'))
                for img_dir in img_dirs:
                    # pid = pid2label[pid]
                    camid = cam2label[cam]
                    if cam == 'A':
                        clothes_id = pid2label[pid] * 2
                        gallery_dataset.append((img_dir, pid, camid, clothes_id))
                    elif cam == 'B':
                        clothes_id = pid2label[pid] * 2
                        query_dataset_same_clothes.append((img_dir, pid, camid, clothes_id))
                    else:
                        clothes_id = pid2label[pid] * 2 + 1
                        query_dataset_diff_clothes.append((img_dir, pid, camid, clothes_id))

        pid2imgidx = {}
        for idx, (img_dir, pid, camid, clothes_id) in enumerate(gallery_dataset):
            if pid not in pid2imgidx:
                pid2imgidx[pid] = []
            pid2imgidx[pid].append(idx)

        # Sample 10 gallery subsets, each holding one randomly chosen image per
        # identity, to perform the single-shot test.
        gallery_idx = {}
        random.seed(3)
        for idx in range(0, 10):
            gallery_idx[idx] = []
            for pid in pid2imgidx:
                gallery_idx[idx].append(random.choice(pid2imgidx[pid]))
                 
        num_imgs_query_same = len(query_dataset_same_clothes)
        num_imgs_query_diff = len(query_dataset_diff_clothes)
        num_imgs_gallery = len(gallery_dataset)

        return query_dataset_same_clothes, query_dataset_diff_clothes, gallery_dataset, \
               num_pids, num_imgs_query_same, num_imgs_query_diff, num_imgs_gallery, \
               num_clothes, gallery_idx
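

# ------------------------------------------------------------------------
# Illustrative sketch (not part of the original Simple-CCReID code).
# _process_dir_test above builds gallery_idx: 10 trials, each holding one
# randomly chosen gallery index per identity, for single-shot evaluation.
# The sketch draws from a private random.Random(seed) instance, which should
# give the same draws as random.seed(seed) plus random.choice for the same
# gallery order, without reseeding Python's global RNG. The exact indices
# only reproduce across machines if the gallery order is fixed (the glob
# results in _process_dir_test are not sorted). The name is hypothetical.
# ------------------------------------------------------------------------
import random


def sketch_single_shot_gallery(gallery_pids, num_trials=10, seed=3):
    """Return {trial: [one gallery index per pid]} for single-shot testing."""
    pid2imgidx = {}
    for idx, pid in enumerate(gallery_pids):
        pid2imgidx.setdefault(pid, []).append(idx)
    rng = random.Random(seed)
    # Trials are drawn in order, and pids in first-seen order, as in the loop above.
    return {trial: [rng.choice(idxs) for idxs in pid2imgidx.values()]
            for trial in range(num_trials)}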


================================================
FILE: data/datasets/vcclothes.py
================================================
import os
import re
import glob
import h5py
import random
import math
import logging
import numpy as np
import os.path as osp
from scipy.io import loadmat
from tools.utils import mkdir_if_missing, write_json, read_json


class VCClothes(object):
    """ VC-Clothes

    Reference:
        Wang et al. When Person Re-identification Meets Changing Clothes. In CVPR Workshop, 2020.

    URL: https://wanfb.github.io/dataset.html
    """
    dataset_dir = 'VC-Clothes'
    def __init__(self, root='data', mode='all', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')
        # 'all': all cameras; 'sc' (same clothes): cameras 2 and 3;
        # 'cc' (clothes changing): cameras 3 and 4
        self.mode = mode 
        self._check_before_run()

        train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = self._process_dir_train()
        query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = self._process_dir_test()
        num_total_pids = num_train_pids + num_test_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs
        num_test_imgs = num_query_imgs + num_gallery_imgs 
        num_total_clothes = num_train_clothes + num_test_clothes

        logger = logging.getLogger('reid.dataset')
        logger.info("=> VC-Clothes loaded")
        logger.info("Dataset statistics:")
        logger.info("  ----------------------------------------")
        logger.info("  subset   | # ids | # images | # clothes")
        logger.info("  ----------------------------------------")
        logger.info("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes))
        logger.info("  test     | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes))
        logger.info("  query    | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs))
        logger.info("  gallery  | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs))
        logger.info("  ----------------------------------------")
        logger.info("  total    | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes))
        logger.info("  ----------------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_train_clothes = num_train_clothes
        self.pid2clothes = pid2clothes

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir_train(self):
        img_paths = glob.glob(osp.join(self.train_dir, '*.jpg'))
        img_paths.sort()
        pattern = re.compile(r'(\d+)-(\d+)-(\d+)-(\d+)')

        pid_container = set()
        clothes_container = set()
        for img_path in img_paths:
            pid, camid, clothes, _ = pattern.search(img_path).groups()
            clothes_id = pid + clothes
            pid, camid = int(pid), int(camid)
            pid_container.add(pid)
            clothes_container.add(clothes_id)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}

        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        dataset = []
        pid2clothes = np.zeros((num_pids, num_clothes))
        for img_path in img_paths:
            pid, camid, clothes, _ = pattern.search(img_path).groups()
            clothes_id = pid + clothes
            pid, camid = int(pid), int(camid)
            camid -= 1 # index starts from 0
            pid = pid2label[pid]
            clothes_id = clothes2label[clothes_id]
            dataset.append((img_path, pid, camid, clothes_id))
            pid2clothes[pid, clothes_id] = 1
        
        num_imgs = len(dataset)

        return dataset, num_pids, num_imgs, num_clothes, pid2clothes

    def _process_dir_test(self):
        query_img_paths = glob.glob(osp.join(self.query_dir, '*.jpg'))
        gallery_img_paths = glob.glob(osp.join(self.gallery_dir, '*.jpg'))
        query_img_paths.sort()
        gallery_img_paths.sort()
        pattern = re.compile(r'(\d+)-(\d+)-(\d+)-(\d+)')

        pid_container = set()
        clothes_container = set()
        for img_path in query_img_paths:
            pid, camid, clothes, _ = pattern.search(img_path).groups()
            clothes_id = pid + clothes
            pid, camid = int(pid), int(camid)
            if self.mode == 'sc' and camid not in [2, 3]:
                continue
            if self.mode == 'cc' and camid not in [3, 4]:
                continue
            pid_container.add(pid)
            clothes_container.add(clothes_id)
        for img_path in gallery_img_paths:
            pid, camid, clothes, _ = pattern.search(img_path).groups()
            clothes_id = pid + clothes
            pid, camid = int(pid), int(camid)
            if self.mode == 'sc' and camid not in [2, 3]:
                continue
            if self.mode == 'cc' and camid not in [3, 4]:
                continue
            pid_container.add(pid)
            clothes_container.add(clothes_id)
        pid_container = sorted(pid_container)
        clothes_container = sorted(clothes_container)
        pid2label = {pid:label for label, pid in enumerate(pid_container)}
        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}

        num_pids = len(pid_container)
        num_clothes = len(clothes_container)

        query_dataset = []
        gallery_dataset = []
        for img_path in query_img_paths:
            pid, camid, clothes, _ = pattern.search(img_path).groups()
            clothes_id = pid + clothes
            pid, camid = int(pid), int(camid)
            if self.mode == 'sc' and camid not in [2, 3]:
                continue
            if self.mode == 'cc' and camid not in [3, 4]:
                continue
            camid -= 1 # index starts from 0
            clothes_id = clothes2label[clothes_id]
            query_dataset.append((img_path, pid, camid, clothes_id))

        for img_path in gallery_img_paths:
            pid, camid, clothes, _ = pattern.search(img_path).groups()
            clothes_id = pid + clothes
            pid, camid = int(pid), int(camid)
            if self.mode == 'sc' and camid not in [2, 3]:
                continue
            if self.mode == 'cc' and camid not in [3, 4]:
                continue
            camid -= 1 # index starts from 0
            clothes_id = clothes2label[clothes_id]
            gallery_dataset.append((img_path, pid, camid, clothes_id))
        
        num_imgs_query = len(query_dataset)
        num_imgs_gallery = len(gallery_dataset)

        return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes


def VCClothesSameClothes(root='data', **kwargs):
    return VCClothes(root=root, mode='sc')


def VCClothesClothesChanging(root='data', **kwargs):
    return VCClothes(root=root, mode='cc')
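The filename parsing and camera filtering in `_process_dir_test` can be sketched as a standalone helper. This is a minimal sketch, not part of the repository. The helper name `parse_vcclothes_name`, the extra `'all'` mode and the example filenames are assumptions for illustration. The regex, the per-mode camera sets and the `pid + clothes` string key mirror the code above.

```python
import re

# Same pattern as _process_dir_test: pid-camid-clothes-frame
PATTERN = re.compile(r'(\d+)-(\d+)-(\d+)-(\d+)')

# Cameras kept in each evaluation mode, mirroring the checks above.
# The 'all' mode (no filtering) is a hypothetical addition for illustration.
ALLOWED_CAMS = {'sc': {2, 3}, 'cc': {3, 4}, 'all': None}


def parse_vcclothes_name(img_path, mode='all'):
    """Return (pid, camid, clothes_key), or None if the camera is filtered out.

    clothes_key concatenates the raw pid and clothes strings, so the same
    clothes index worn by two different people stays distinct.
    """
    pid, camid, clothes, _ = PATTERN.search(img_path).groups()
    clothes_key = pid + clothes
    pid, camid = int(pid), int(camid)
    allowed = ALLOWED_CAMS[mode]
    if allowed is not None and camid not in allowed:
        return None
    return pid, camid - 1, clothes_key  # camid re-indexed to start from 0
```

The filter is applied to the 1-based camera id before it is shifted to start from 0, exactly as in the loops above.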


================================================
FILE: data/img_transforms.py
================================================
from torchvision.transforms import *
from PIL import Image
import random
import math


class ResizeWithEqualScale(object):
    """
    Resize an image to (height, width) while keeping its aspect ratio, padding the uncovered area with fill_color.

    Args:
        height (int): resized height.
        width (int): resized width.
        interpolation: PIL interpolation filter. Default: Image.BILINEAR.
        fill_color (tuple): color for padding.
    """
    def __init__(self, height, width, interpolation=Image.BILINEAR, fill_color=(0,0,0)):
        self.height = height
        self.width = width
        self.interpolation = interpolation
        self.fill_color = fill_color

    def __call__(self, img):
        width, height = img.size
        if self.height / self.width >= height / width:
            height = int(self.width * (height / width))
            width = self.width
        else:
            width = int(self.height * (width / height))
            height = self.height 

        resized_img = img.resize((width, height), self.interpolation)
        new_img = Image.new('RGB', (self.width, self.height), self.fill_color)
        new_img.paste(resized_img, (int((self.width - width) / 2), int((self.height - height) / 2)))

        return new_img
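The geometry behind `ResizeWithEqualScale` (fit one side, then center the resized image on a padded canvas) can be checked without PIL. A minimal sketch: `letterbox_geometry` is a hypothetical helper that mirrors the arithmetic of `__call__` and returns the resized size plus the paste offset.

```python
def letterbox_geometry(src_w, src_h, dst_h, dst_w):
    """Mirror ResizeWithEqualScale: return (new_w, new_h, off_x, off_y)."""
    if dst_h / dst_w >= src_h / src_w:
        # target is relatively taller: fit the width, pad top and bottom
        new_h = int(dst_w * (src_h / src_w))
        new_w = dst_w
    else:
        # target is relatively wider: fit the height, pad left and right
        new_w = int(dst_h * (src_w / src_h))
        new_h = dst_h
    off_x = int((dst_w - new_w) / 2)
    off_y = int((dst_h - new_h) / 2)
    return new_w, new_h, off_x, off_y
```

For example, a 100x100 image placed on a 256x128 (height x width) canvas becomes 128x128 and is pasted 64 pixels from the top.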


class RandomCroping(object):
    """
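    With probability p, enlarge the image by a factor of 1.125 (1 + 1/8) and then
    randomly crop a region of the original size from it; otherwise return it unchanged.

    Args:
        p (float): probability of performing this transformation. Default: 0.5.
        interpolation: PIL interpolation filter used for enlarging. Default: Image.BILINEAR.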
    """
    def __init__(self, p=0.5, interpolation=Image.BILINEAR):
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        width, height = img.size
        if random.uniform(0, 1) >= self.p:
            return img
        
        new_width, new_height = int(round(width * 1.125)), int(round(height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)
        x_maxrange = new_width - width
        y_maxrange = new_height - height
        x1 = int(round(random.uniform(0, x_maxrange)))
        y1 = int(round(random.uniform(0, y_maxrange)))
        cropped_img = resized_img.crop((x1, y1, x1 + width, y1 + height))

        return cropped_img
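The crop box chosen by `RandomCroping` always keeps the original size, so the random offset is bounded by the 12.5% enlargement. A minimal sketch: `random_crop_box` is a hypothetical helper that reproduces the arithmetic with an injectable random generator, and it returns the box in the (left, upper, right, lower) order that PIL's `Image.crop` expects.

```python
import random


def random_crop_box(width, height, rng=random):
    """Crop box used by RandomCroping after enlarging the image by 1.125x."""
    new_w, new_h = int(round(width * 1.125)), int(round(height * 1.125))
    x1 = int(round(rng.uniform(0, new_w - width)))
    y1 = int(round(rng.uniform(0, new_h - height)))
    # (left, upper, right, lower), the box format PIL's Image.crop expects
    return new_w, new_h, (x1, y1, x1 + width, y1 + height)
```

For a 128x256 (width x height) person crop, the enlarged image is 144x288, so the horizontal offset lies in [0, 16] and the vertical one in [0, 32].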


class RandomErasing(object):
    """ 
    Randomly selects a rectangle region in an image and erases its pixels.

    Reference:
        Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017.

    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Per-channel values used to fill the erased area.
    """
    
    def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1
       
    def __call__(self, img):

        if random.uniform(0, 1) >= self.probability:
            return img

        for attempt in range(100):
            area = img.size()[1] * img.size()[2]
       
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1/self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < img.size()[2] and h < img.size()[1]:
                x1 = random.randint(0, img.size()[1] - h)
                y1 = random.randint(0, img.size()[2] - w)
                if img.size()[0] == 3:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]
                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]
                else:
                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]
                return img

        return img
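The rectangle sampling inside `RandomErasing.__call__` can be isolated from the tensor writes. A minimal sketch: `sample_erase_box` is a hypothetical helper that follows the same loop (random area fraction, random aspect ratio, up to 100 attempts). Note that in the code above `x1` indexes rows and `y1` indexes columns of a (C, H, W) tensor, which the helper names `top` and `left`.

```python
import math
import random


def sample_erase_box(img_h, img_w, sl=0.02, sh=0.4, r1=0.3, rng=random, attempts=100):
    """Return (top, left, h, w) of an erasing rectangle, or None if all attempts fail."""
    area = img_h * img_w
    for _ in range(attempts):
        target_area = rng.uniform(sl, sh) * area
        aspect_ratio = rng.uniform(r1, 1 / r1)
        h = int(round(math.sqrt(target_area * aspect_ratio)))
        w = int(round(math.sqrt(target_area / aspect_ratio)))
        # accept only rectangles strictly smaller than the image
        if w < img_w and h < img_h:
            top = rng.randint(0, img_h - h)   # corresponds to x1 (rows)
            left = rng.randint(0, img_w - w)  # corresponds to y1 (columns)
            return top, left, h, w
    return None
```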

================================================
FILE: data/samplers.py
================================================
import copy
import math
import random
import numpy as np
from torch import distributed as dist
from collections import defaultdict
from torch.utils.data.sampler import Sampler


class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
        data_source (Dataset): dataset to sample from.
        num_instances (int): number of instances per identity.
    """
    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

        # compute number of examples in an epoch
        self.length = 0
        for pid in self.pids:
            idxs = self.index_dic[pid]
            num = len(idxs)
            if num < self.num_instances:
                num = self.num_instances
            self.length += num - num % self.num_instances

    def __iter__(self):
        list_container = []

        for pid in self.pids:
            idxs = copy.deepcopy(self.index_dic[pid])
            if len(idxs) < self.num_instances:
                idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
            random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(idx)
                if len(batch_idxs) == self.num_instances:
                    list_container.append(batch_idxs)
                    batch_idxs = []

        random.shuffle(list_container)

        ret = []
        for batch_idxs in list_container:
            ret.extend(batch_idxs)

        return iter(ret)

    def __len__(self):
        return self.length
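The ordering produced by `RandomIdentitySampler` can be reproduced on a plain list of identity labels. A minimal sketch: `pk_index_order` is a hypothetical helper, and it swaps `np.random.choice` for `random.Random.choice` to stay dependency-free. Each identity contributes floor(n / K) chunks of K indices, identities with fewer than K samples are oversampled with replacement, and the chunks are shuffled, so every consecutive group of K indices shares one identity.

```python
import random
from collections import defaultdict


def pk_index_order(pids, num_instances=4, rng=random):
    """Sketch of RandomIdentitySampler.__iter__ on a plain list of pids."""
    index_dic = defaultdict(list)
    for index, pid in enumerate(pids):
        index_dic[pid].append(index)
    chunks = []
    for idxs in index_dic.values():
        idxs = list(idxs)
        if len(idxs) < num_instances:
            # oversample with replacement up to exactly K instances
            idxs = [rng.choice(idxs) for _ in range(num_instances)]
        rng.shuffle(idxs)
        # keep only complete chunks of K indices; the remainder is dropped
        for start in range(0, len(idxs) - num_instances + 1, num_instances):
            chunks.append(idxs[start:start + num_instances])
    rng.shuffle(chunks)
    return [i for chunk in chunks for i in chunk]
```

With identities having 5, 2 and 8 images and K = 4, the epoch length is 4 + 4 + 8 = 16, matching `self.length`.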


class DistributedRandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - data_source (Dataset): dataset to sample from.
    - num_instances (int): number of instances per identity.
    - num_replicas (int, optional): Number of processes participating in
        distributed training. By default, :attr:`world_size` is retrieved from the
        current distributed group.
    - rank (int, optional): Rank of the current process within :attr:`num_replicas`.
        By default, :attr:`rank` is retrieved from the current distributed group.
    - seed (int, optional): random seed used to shuffle the sampler. 
        This number should be identical across all
        processes in the distributed group. Default: ``0``.
    """
    def __init__(self, data_source, num_instances=4, 
                 num_replicas=None, rank=None, seed=0):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = 0

        self.data_source = data_source
        self.num_instances = num_instances
        self.index_dic = defaultdict(list)
        for index, (_, pid, _, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

        # compute number of examples in an epoch
        self.length = 0
        for pid in self.pids:
            idxs = self.index_dic[pid]
            num = len(idxs)
            if num < self.num_instances:
                num = self.num_instances
            self.length += num - num % self.num_instances
        assert self.length % self.num_instances == 0

        if self.length // self.num_instances % self.num_replicas != 0: 
            self.num_samples = math.ceil((self.length // self.num_instances - self.num_replicas) / self.num_replicas) * self.num_instances
        else:
            self.num_samples = math.ceil(self.length / self.num_replicas) 
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch and seed
        random.seed(self.seed + self.epoch)
        np.random.seed(self.seed + self.epoch)

        list_container = []
        for pid in self.pids:
            idxs = copy.deepcopy(self.index_dic[pid])
            if len(idxs) < self.num_instances:
                idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
            random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(idx)
                if len(batch_idxs) == self.num_instances:
                    list_container.append(batch_idxs)
                    batch_idxs = []
        random.shuffle(list_container)

        # remove tail of data to make it evenly divisible.
        list_container = list_container[:self.total_size//self.num_instances]
        assert len(list_container) == self.total_size//self.num_instances

        # subsample
        list_container = list_container[self.rank:self.total_size//self.num_instances:self.num_replicas]
        assert len(list_container) == self.num_samples//self.num_instances

        ret = []
        for batch_idxs in list_container:
            ret.extend(batch_idxs)

        return iter(ret)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """
        Sets the epoch for this sampler. This ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
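The per-rank sample count in `DistributedRandomIdentitySampler.__init__` drops the tail so that every rank receives the same whole number of K-sized chunks. A minimal sketch of that arithmetic: `per_rank_num_samples` is a hypothetical helper. When the chunk count N is not divisible by the number of replicas R, `ceil((N - R) / R)` equals `floor(N / R)`.

```python
import math


def per_rank_num_samples(length, num_instances, num_replicas):
    """Reproduce the sampler's (num_samples, total_size) arithmetic."""
    num_chunks = length // num_instances
    if num_chunks % num_replicas != 0:
        # drop the tail so every rank gets the same number of whole chunks
        num_samples = math.ceil((num_chunks - num_replicas) / num_replicas) * num_instances
    else:
        num_samples = math.ceil(length / num_replicas)
    return num_samples, num_samples * num_replicas
```

For example, 40 samples in chunks of 4 over 4 GPUs gives 10 chunks; 2 chunks are dropped and each rank gets 8 samples.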


class DistributedInferenceSampler(Sampler):
    """
    refer to: https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py

    Distributed Sampler that subsamples indices sequentially,
    making it easier to collate all results at the end.
    Even though we only use this sampler for eval and predict (no training),
    which means that the model params won't have to be synced (i.e. will not hang
    for synchronization even if varied number of forward passes), we still add extra
    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
    """
    def __init__(self, dataset, rank=None, num_replicas=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += [indices[-1]] * (self.total_size - len(indices))
        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        return iter(indices)

    def __len__(self):
        return self.num_samples
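`DistributedInferenceSampler` gives each rank a contiguous shard and pads the tail by repeating the last index. A minimal sketch: `inference_shards` is a hypothetical helper that returns the index list each rank would iterate over.

```python
import math


def inference_shards(dataset_len, num_replicas):
    """Contiguous, padded shards as produced by DistributedInferenceSampler."""
    num_samples = int(math.ceil(dataset_len / num_replicas))
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    indices += [indices[-1]] * (total_size - len(indices))  # pad with the last index
    return [indices[r * num_samples:(r + 1) * num_samples] for r in range(num_replicas)]
```

Because the shards are contiguous, the features gathered from all ranks can be concatenated in rank order and truncated back to `dataset_len` to drop the padded duplicates.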

================================================
FILE: data/spatial_transforms.py
================================================
import random
import math
import numbers
import collections.abc
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image, ImageOps
try:
    import accimage
except ImportError:
    accimage = None


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def randomize_parameters(self):
        for t in self.transforms:
            t.randomize_parameters()
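`randomize_parameters()` lets one random draw be shared by every frame of a clip, so a video is augmented consistently. A minimal sketch of that calling pattern: `ToyFlip` and `transform_clip` are hypothetical stand-ins that operate on plain lists instead of PIL images.

```python
import random


class ToyFlip:
    """Hypothetical stand-in for RandomHorizontalFlip operating on lists."""

    def randomize_parameters(self):
        self.flip = random.random() < 0.5

    def __call__(self, frame):
        return frame[::-1] if self.flip else frame


def transform_clip(transforms, frames):
    # Draw the random parameters once per clip, then reuse them for every
    # frame so the whole clip receives the same augmentation.
    for t in transforms:
        t.randomize_parameters()
    out = []
    for frame in frames:
        for t in transforms:
            frame = t(frame)
        out.append(frame)
    return out
```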


class ToTensor(object):
    """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.
    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __init__(self, norm_value=255):
        self.norm_value = norm_value

    def __call__(self, pic):
        """
        Args:
            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.float().div(self.norm_value)

        if accimage is not None and isinstance(pic, accimage.Image):
            nppic = np.zeros(
                [pic.channels, pic.height, pic.width], dtype=np.float32)
            pic.copyto(nppic)
            return torch.from_numpy(nppic)

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)
        # put it from HWC to CHW format
        # yikes, this transpose takes 80% of the loading time/CPU
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float().div(self.norm_value)
        else:
            return img

    def randomize_parameters(self):
        pass
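For the common 8-bit RGB case, `ToTensor` amounts to an HWC to CHW transpose followed by division by `norm_value`. A minimal NumPy sketch of that conversion: `hwc_uint8_to_chw_float` is a hypothetical helper and returns a NumPy array rather than a torch tensor.

```python
import numpy as np


def hwc_uint8_to_chw_float(arr, norm_value=255):
    """NumPy analogue of ToTensor for an (H, W, C) uint8 array."""
    chw = np.ascontiguousarray(arr.transpose(2, 0, 1))  # HWC -> CHW
    return chw.astype(np.float32) / norm_value
```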


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std

    Args:
        mean (sequence): Sequence of means for R, G, B channels respectively.
        std (sequence): Sequence of standard deviations for R, G, B channels
            respectively.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """
        # TODO: make efficient
        for t, m, s in zip(tensor, self.mean, self.std):
            t.sub_(m).div_(s)
        return tensor

    def randomize_parameters(self):
        pass


class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number,
            i.e., if height > width, then the image will be rescaled to
            (size * height / width, size) in (h, w) order.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image.
        """
        if isinstance(self.size, int):
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            return img.resize(self.size[::-1], self.interpolation)

    def randomize_parameters(self):
        pass
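The output size chosen by `Scale` depends on whether `size` is an int (match the shorter edge) or an (h, w) sequence, which is reversed because PIL's `resize` takes (width, height). A minimal sketch: `scale_output_size` is a hypothetical helper that returns the (width, height) `Scale` would pass to `resize`.

```python
def scale_output_size(img_w, img_h, size):
    """Return the (width, height) that Scale would pass to PIL's resize."""
    if isinstance(size, int):
        if (img_w <= img_h and img_w == size) or (img_h <= img_w and img_h == size):
            return img_w, img_h  # shorter side already matches: unchanged
        if img_w < img_h:
            return size, int(size * img_h / img_w)
        return int(size * img_w / img_h), size
    # sequence sizes are (height, width); PIL expects (width, height)
    return tuple(size[::-1])
```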


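class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL.Image with a probability of 0.5.

    The flip decision is drawn in randomize_parameters(), which must be called
    before __call__ so that all frames of a clip are flipped (or not) together.
    """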

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.
        Returns:
            PIL.Image: Randomly flipped image.
        """
        if self.p < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img

    def randomize_parameters(self):
        self.p = random.random()


class RandomCrop(object):
    """
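    With probability p, first enlarge the image to 1.125 (1 + 1/8) times the target
    size, then crop a (height, width) region at the offset drawn in randomize_parameters();
    otherwise simply resize the image to (height, width).

    Args:
        size (int or tuple): target size as (height, width); an int gives a square output.
        p (float): probability of performing this transformation. Default: 0.5.
        interpolation: PIL interpolation filter. Default: Image.BILINEAR.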
    """
    def __init__(self, size, p=0.5, interpolation=Image.BILINEAR):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

        self.height, self.width = self.size
        self.p = p
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if not self.cropping:
            return img.resize((self.width, self.height), self.interpolation)
        
        new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125))
        resized_img = img.resize((new_width, new_height), self.interpolation)
        x_maxrange = new_width - self.width
        y_maxrange = new_height - self.height
        x1 = int(round(self.tl_x * x_maxrange))
        y1 = int(round(self.tl_y * y_maxrange))
        return resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))

    def randomize_parameters(self):
        self.cropping = random.uniform(0, 1) < self.p
        self.tl_x = random.random()
        self.tl_y = random.random()


class RandomErasing(object):
    """ 
    Randomly selects a rectangle region in an image and erases its pixels.

    Reference:
        Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017.
        
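    Args:
        height (int): height of the input frames. Default: 256.
        width (int): width of the input frames. Default: 128.
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Per-channel values used to fill the erased area.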
    """
    
    def __init__(self, height=256, width=128, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.485, 0.456, 0.406]):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1
        self.height = height
        self.width = width
       
    def __call__(self, img):
        if not self.re:
            # no erasing rectangle was drawn for this clip
            return img

        if img.size()[0] == 3:
            img[0, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[0]
            img[1, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[1]
            img[2, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[2]
        else:
            img[0, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[0]
        return img

    def randomize_parameters(self):
        self.re = random.uniform(0, 1) < self.probability
        self.h, self.w, self.x1, self.y1 = 0, 0, 0, 0
        whether_re = False
        if self.re:
            for attempt in range(100):
                area = self.height*self.width

                target_area = random.uniform(self.sl, self.sh) * area
                aspect_ratio = random.uniform(self.r1, 1/self.r1)

                self.h = int(round(math.sqrt(target_area * aspect_ratio)))
                self.w = int(round(math.sqrt(target_area / aspect_ratio)))
                if self.w < self.width and self.h < self.height:
                    self.x1 = random.randint(0, self.height - self.h)
                    self.y1 = random.randint(0, self.width - self.w)
                    whether_re = True
                    break

        self.re = whether_re
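The video version of random erasing draws one rectangle per clip in `randomize_parameters()` and then applies the same box to every frame. A minimal NumPy sketch of the application step: `erase_clip` is a hypothetical helper that assumes the box was already sampled (or is `None` when no erasing was drawn) and, unlike the in-place original, works on copies.

```python
import numpy as np


def erase_clip(frames, box, mean):
    """Erase the same (top, left, h, w) rectangle in every (C, H, W) frame."""
    if box is None:
        return frames  # no erasing was drawn for this clip
    top, left, h, w = box
    out = []
    for frame in frames:
        frame = frame.copy()
        for c in range(frame.shape[0]):
            frame[c, top:top + h, left:left + w] = mean[c]
        out.append(frame)
    return out
```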

================================================
FILE: data/temporal_transforms.py
================================================
import random
import numpy as np


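class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    If the track has at least size * stride frames, a random window is sampled
    every `stride` frames. If it has at least `size` frames, it is split into
    `size` equal segments and one random frame is taken from each. Otherwise,
    `size` indices are sampled with replacement and sorted to keep temporal order.

    Args:
        size (int): Desired output size of the crop.
        stride (int): Temporal sampling stride.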
    """

    def __init__(self, size=4, stride=8):
        self.size = size
        self.stride = stride

    def __call__(self, frame_indices):
        """
        Args:
            frame_indices (list): frame indices to be cropped.
        Returns:
            list: Cropped frame indices.
        """
        frame_indices = list(frame_indices)

        if len(frame_indices) >= self.size * self.stride:
            rand_end = len(frame_indices) - (self.size - 1) * self.stride - 1
            begin_index = random.randint(0, rand_end)
            end_index = begin_index + (self.size - 1) * self.stride + 1
            out = frame_indices[begin_index:end_index:self.stride]
        elif len(frame_indices) >= self.size:
            clips = []
            for i in range(self.size):
                clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)])
            out = []
            for i in range(self.size):
                out.append(random.choice(clips[i]))
        else:
            index = np.random.choice(len(frame_indices), size=self.size, replace=True)
            index.sort()
            out = [frame_indices[index[i]] for i in range(self.size)]

        return out
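The three sampling regimes of `TemporalRandomCrop` can be written as one pure function. A minimal sketch: `temporal_random_crop` is a hypothetical helper with an injectable generator, and it replaces `np.random.choice` with `random.Random.randrange` for the short-track case.

```python
import random


def temporal_random_crop(frame_indices, size=4, stride=8, rng=random):
    """Pure-Python sketch of TemporalRandomCrop's three sampling regimes."""
    frame_indices = list(frame_indices)
    n = len(frame_indices)
    if n >= size * stride:
        # long track: one random window, sampled every `stride` frames
        begin = rng.randint(0, n - (size - 1) * stride - 1)
        return frame_indices[begin:begin + (size - 1) * stride + 1:stride]
    if n >= size:
        # medium track: one random frame from each of `size` equal segments
        seg = n // size
        return [rng.choice(frame_indices[seg * i:seg * (i + 1)]) for i in range(size)]
    # short track: sample with replacement, then restore temporal order
    picks = sorted(rng.randrange(n) for _ in range(size))
    return [frame_indices[i] for i in picks]
```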


class TemporalBeginCrop(object):
    """Temporally crop the given frame indices at the beginning.

    If the number of frames is less than the size,
    loop the indices as many times as necessary to satisfy the size.

    Args:
        size (int): Desired output size of the crop.
        stride (int): Temporal sampling stride
    """

    def __init__(self, size=8, stride=4):
        self.size = size
        self.stride = stride
        
    def __call__(self, frame_indices):
        frame_indices = list(frame_indices)

        if len(frame_indices) >= self.size * self.stride:
            out = frame_indices[0 : self.size * self.stride : self.stride]
        else:
            out = frame_indices[0 : self.size]
            while len(out) < self.size:
                for index in out:
                    if len(out) >= self.size:
                        break
                    out.append(index)

        return out
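The looping branch of `TemporalBeginCrop` appends to `out` while iterating over it. The same result can be expressed with `itertools`. A minimal sketch: `temporal_begin_crop` is a hypothetical equivalent formulation, not the repository's code.

```python
from itertools import cycle, islice


def temporal_begin_crop(frame_indices, size=8, stride=4):
    """Equivalent formulation of TemporalBeginCrop using itertools."""
    frame_indices = list(frame_indices)
    if len(frame_indices) >= size * stride:
        return frame_indices[0:size * stride:stride]
    # short track: take the first `size` frames and loop them until full
    return list(islice(cycle(frame_indices[:size]), size))
```

For a 5-frame track with the default size of 8, both versions return `[0, 1, 2, 3, 4, 0, 1, 2]`.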


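class TemporalDivisionCrop(object):
    """Temporally crop the given frame indices by segment-based (TSN-style) sampling:
    split the indices into `size` equal segments and pick one random frame from each.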

    Args:
        size (int): Desired output size of the crop.
    """
    def __init__(self, size=4):
        self.size = size

    def __call__(self, frame_indices):
        """
        Args:
            frame_indices (list): frame indices to be cropped.
        Returns:
            list: Cropped frame indices.
        """
        frame_indices = list(frame_indices)

        if len(frame_indices) >= self.size:
            clips = []
            for i in range(self.size):
                clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)])
            out = []
            for i in range(self.size):
                out.append(random.choice(clips[i]))
        else:
            index = np.random.choice(len(frame_indices), size=self.size, replace=True)
            index.sort()
            out = [frame_indices[index[i]] for i in range(self.size)]

        return out
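The segments used by `TemporalDivisionCrop` (and by the middle branch of `TemporalRandomCrop`) have length `len // size`. Any trailing frames beyond `size * (len // size)` are therefore never sampled. A minimal sketch: `division_segments` is a hypothetical helper that lists the candidate positions of each segment.

```python
def division_segments(num_frames, size=4):
    """Segments from which TemporalDivisionCrop draws one frame each (num_frames >= size)."""
    seg = num_frames // size
    return [list(range(seg * i, seg * (i + 1))) for i in range(size)]
```

With 10 frames and `size=4`, the segments are `[0, 1]`, `[2, 3]`, `[4, 5]` and `[6, 7]`, so frames 8 and 9 are never selected.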


================================================
FILE: losses/__init__.py
================================================
from torch import nn
from losses.cross_entropy_loss_with_label_smooth import CrossEntropyWithLabelSmooth
from losses.triplet_loss import TripletLoss
from losses.contrastive_loss import ContrastiveLoss
from losses.arcface_loss import ArcFaceLoss
from losses.cosface_loss import CosFaceLoss, PairwiseCosFaceLoss
from losses.circle_loss import CircleLoss, PairwiseCircleLoss
from losses.clothes_based_adversarial_loss import ClothesBasedAdversarialLoss, ClothesBasedAdversarialLossWithMemoryBank


def build_losses(config, num_train_clothes):
    # Build identity classification loss
    if config.LOSS.CLA_LOSS == 'crossentropy':
        criterion_cla = nn.CrossEntropyLoss()
    elif config.LOSS.CLA_LOSS == 'crossentropylabelsmooth':
        criterion_cla = CrossEntropyWithLabelSmooth()
    elif config.LOSS.CLA_LOSS == 'arcface':
        criterion_cla = ArcFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
    elif config.LOSS.CLA_LOSS == 'cosface':
        criterion_cla = CosFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
    elif config.LOSS.CLA_LOSS == 'circle':
        criterion_cla = CircleLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)
    else:
        raise KeyError("Invalid classification loss: '{}'".format(config.LOSS.CLA_LOSS))

    # Build pairwise loss
    if config.LOSS.PAIR_LOSS == 'triplet':
        criterion_pair = TripletLoss(margin=config.LOSS.PAIR_M)
    elif config.LOSS.PAIR_LOSS == 'contrastive':
        criterion_pair = ContrastiveLoss(scale=config.LOSS.PAIR_S)
    elif config.LOSS.PAIR_LOSS == 'cosface':
        criterion_pair = PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)
    elif config.LOSS.PAIR_LOSS == 'circle':
        criterion_pair = PairwiseCircleLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)
    else:
        raise KeyError("Invalid pairwise loss: '{}'".format(config.LOSS.PAIR_LOSS))

    # Build clothes classification loss
    if config.LOSS.CLOTHES_CLA_LOSS == 'crossentropy':
        criterion_clothes = nn.CrossEntropyLoss()
    elif config.LOSS.CLOTHES_CLA_LOSS == 'cosface':
        criterion_clothes = CosFaceLoss(scale=config.LOSS.CLA_S, margin=0)
    else:
        raise KeyError("Invalid clothes classification loss: '{}'".format(config.LOSS.CLOTHES_CLA_LOSS))

    # Build clothes-based adversarial loss
    if config.LOSS.CAL == 'cal':
        criterion_cal = ClothesBasedAdversarialLoss(scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON)
    elif config.LOSS.CAL == 'calwithmemory':
        criterion_cal = ClothesBasedAdversarialLossWithMemoryBank(num_clothes=num_train_clothes, feat_dim=config.MODEL.FEATURE_DIM,
                             momentum=config.LOSS.MOMENTUM, scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON)
    else:
        raise KeyError("Invalid clothes-based adversarial loss: '{}'".format(config.LOSS.CAL))

    return criterion_cla, criterion_pair, criterion_clothes, criterion_cal


================================================
FILE: losses/arcface_loss.py
================================================
import math
import torch
import torch.nn.functional as F
from torch import nn


class ArcFaceLoss(nn.Module):
    """ ArcFace loss.

    Reference:
        Deng et al. ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=16, margin=0.1):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        # get a one-hot index
        index = inputs.data * 0.0 
        index.scatter_(1, targets.data.view(-1, 1), 1)
        index = index.bool()

        cos_m = math.cos(self.m)
        sin_m = math.sin(self.m)
        cos_t = inputs[index]
        sin_t = torch.sqrt(1.0 - cos_t * cos_t)
        cos_t_add_m = cos_t * cos_m  - sin_t * sin_m

        cond_v = cos_t - math.cos(math.pi - self.m)
        cond = F.relu(cond_v)
        keep = cos_t - math.sin(math.pi - self.m) * self.m

        cos_t_add_m = torch.where(cond.bool(), cos_t_add_m, keep)

        output = inputs * 1.0 
        output[index] = cos_t_add_m
        output = self.s * output

        return F.cross_entropy(output, targets)
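For the target class, `ArcFaceLoss` replaces cos(θ) with cos(θ + m) while θ + m stays within [0, π]. Otherwise it falls back to the linear penalty `cos_t - sin(π - m) * m`. A minimal scalar sketch: `arcface_target_logit` is a hypothetical helper returning the target logit before multiplication by the scale `s`. It clamps `1 - cos²` at zero, an addition not present in the code above.

```python
import math


def arcface_target_logit(cos_t, m=0.1):
    """Target-class logit (before scaling by s) computed by ArcFaceLoss."""
    sin_t = math.sqrt(max(0.0, 1.0 - cos_t * cos_t))
    if cos_t > math.cos(math.pi - m):
        # theta + m still lies in [0, pi]: use cos(theta + m) directly
        return cos_t * math.cos(m) - sin_t * math.sin(m)
    # otherwise fall back to a linear penalty to keep the logit monotonic
    return cos_t - math.sin(math.pi - m) * m
```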


================================================
FILE: losses/circle_loss.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from losses.gather import GatherLayer


class CircleLoss(nn.Module):
    """ Circle Loss based on the predictions of classifier.

    Reference:
        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=96, margin=0.3, **kwargs):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        mask = torch.zeros_like(inputs)
        mask.scatter_(1, targets.view(-1, 1), 1.0)

        # self-paced weights: alpha_p = [1 + m - s_p]_+ and alpha_n = [s_n + m]_+
        pos_scale = self.s * F.relu(1 + self.m - inputs.detach())
        neg_scale = self.s * F.relu(inputs.detach() + self.m)
        scale_matrix = pos_scale * mask + neg_scale * (1 - mask)

        # margins: Delta_p = 1 - m for the target class, Delta_n = m for all other classes
        scores = (inputs - (1 - self.m) * mask - self.m * (1 - mask)) * scale_matrix

        loss = F.cross_entropy(scores, targets)

        return loss


class PairwiseCircleLoss(nn.Module):
    """ Circle Loss among sample pairs.

    Reference:
        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=48, margin=0.35, **kwargs):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: sample features (before classifier) with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        # l2-normalize
        inputs = F.normalize(inputs, p=2, dim=1)

        # gather all samples from different GPUs as gallery to compute pairwise loss.
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
        m = targets.size(0)  # local batch size

        # compute cosine similarity
        similarities = torch.matmul(inputs, gallery_inputs.t())

        # get mask for pos/neg pairs
        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
        mask = torch.eq(targets, gallery_targets.T).float()
        # this process's own samples occupy columns [rank * m, (rank + 1) * m) of the gathered
        # gallery, which assumes every process has the same batch size
        mask_self = torch.zeros_like(mask)
        rank = dist.get_rank()
        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m, device=mask.device)
        mask_pos = mask - mask_self
        mask_neg = 1 - mask

        pos_scale = self.s * F.relu(1 + self.m - similarities.detach())
        neg_scale = self.s * F.relu(similarities.detach() + self.m)
        scale_matrix = pos_scale * mask_pos + neg_scale * mask_neg

        scores = (similarities - self.m) * mask_neg + (1 - self.m - similarities) * mask_pos
        scores = scores * scale_matrix
        
        neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1)
        pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1)

        loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean()

        return loss
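
A per-anchor version of `PairwiseCircleLoss` can make the weighting easier to follow. The sketch below is a hypothetical, framework-free helper and is not part of the repository. It takes one anchor's cosine similarities to its positives and its negatives, then applies the same weights and margins as the module:

- `[s_n + m]_+ (s_n - m)` for negatives;
- `[1 + m - s_p]_+ (1 - m - s_p)` for positives.

The two terms are combined as `softplus(LSE_n + LSE_p)`.

```python
import math


def logsumexp(values):
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def softplus(x):
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_circle_loss_single(pos_sims, neg_sims, scale=48.0, margin=0.35):
    """Circle loss for one anchor from its cosine similarities to positives and negatives."""
    # negatives: weight [s_n + m]_+, margin Delta_n = m
    neg_terms = [scale * max(s + margin, 0.0) * (s - margin) for s in neg_sims]
    # positives: weight [1 + m - s_p]_+, margin Delta_p = 1 - m, sign flipped so a larger s_p lowers the loss
    pos_terms = [scale * max(1.0 + margin - s, 0.0) * (1.0 - margin - s) for s in pos_sims]
    return softplus(logsumexp(neg_terms) + logsumexp(pos_terms))


print(pairwise_circle_loss_single([0.9, 0.7], [0.1, 0.3]))
```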


================================================
FILE: losses/clothes_based_adversarial_loss.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
from losses.gather import GatherLayer


class ClothesBasedAdversarialLoss(nn.Module):
    """ Clothes-based Adversarial Loss.

    Reference:
        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022.

    Args:
        scale (float): scaling factor.
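        epsilon (float): trade-off weight; the soft target puts (1 - epsilon) on the sample's own clothes
            class and spreads epsilon evenly over all positive clothes classes.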
    """
    def __init__(self, scale=16, epsilon=0.1):
        super().__init__()
        self.scale = scale
        self.epsilon = epsilon

    def forward(self, inputs, targets, positive_mask):
        """
        Args:
            inputs: cosine predictions of the clothes classifier (before scaling and softmax) with shape
                (batch_size, num_clothes)
            targets: ground-truth clothes labels with shape (batch_size)
            positive_mask: positive mask matrix with shape (batch_size, num_clothes). Clothes classes that
                belong to the same identity as the anchor sample are positive (mask value 1); clothes classes
                of other identities are negative (mask value 0).
        """
        inputs = self.scale * inputs
        negative_mask = 1 - positive_mask
        identity_mask = torch.zeros_like(inputs).scatter_(1, targets.unsqueeze(1), 1)

        # log-probability of each class against itself plus all negative clothes classes,
        # so positive clothes classes do not compete with one another
        exp_logits = torch.exp(inputs)
        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negative_mask).sum(1, keepdim=True) + exp_logits)
        log_prob = inputs - log_sum_exp_pos_and_all_neg

        # soft target: (1 - epsilon) on the sample's own clothes class plus epsilon spread over positives
        mask = (1 - self.epsilon) * identity_mask + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask
        loss = (- mask * log_prob).sum(1).mean()

        return loss


class ClothesBasedAdversarialLossWithMemoryBank(nn.Module):
    """ Clothes-based Adversarial Loss between mini batch and the samples in memory.

    Reference:
        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022.

    Args:
        num_clothes (int): the number of clothes classes.
        feat_dim (int): the dimensions of feature.
        momentum (float): momentum to update memory.
        scale (float): scaling factor.
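        epsilon (float): trade-off weight; the soft target puts (1 - epsilon) on the sample's own clothes
            class and spreads epsilon evenly over all positive clothes classes.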
    """
    def __init__(self, num_clothes, feat_dim, momentum=0., scale=16, epsilon=0.1):
        super().__init__()
        self.num_clothes = num_clothes
        self.feat_dim = feat_dim
        self.momentum = momentum
        self.epsilon = epsilon
        self.scale = scale

        self.register_buffer('feature_memory', torch.zeros((num_clothes, feat_dim)))
        # clothes label stored in each memory slot; -1 marks a slot that has not been filled yet
        self.register_buffer('label_memory', torch.full((num_clothes,), -1, dtype=torch.int64))
        self.has_been_filled = False

    def forward(self, inputs, targets, positive_mask):
        """
        Args:
            inputs: sample features (before classifier) with shape (batch_size, feat_dim)
            targets: ground-truth clothes labels with shape (batch_size)
            positive_mask: positive mask matrix with shape (batch_size, num_clothes), defined as in
                ClothesBasedAdversarialLoss.
        """
        # gather all samples from different GPUs to update memory.
        gathered_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gathered_targets = torch.cat(GatherLayer.apply(targets), dim=0)
        self._update_memory(gathered_inputs.detach(), gathered_targets)

        inputs_norm = F.normalize(inputs, p=2, dim=1)
        memory_norm = F.normalize(self.feature_memory.detach(), p=2, dim=1)
        similarities = torch.matmul(inputs_norm, memory_norm.t()) * self.scale

        # clone so that masking empty memory slots below does not modify the caller's tensor
        positive_mask = positive_mask.clone()
        negative_mask = 1 - positive_mask
        mask_identity = torch.zeros_like(similarities).scatter_(1, targets.unsqueeze(1), 1)

        if not self.has_been_filled:
            # ignore memory slots that have not received a feature yet (label -1)
            invalid_index = self.label_memory == -1
            positive_mask[:, invalid_index] = 0
            negative_mask[:, invalid_index] = 0
            if not invalid_index.any():
                self.has_been_filled = True
                print('Memory bank is full')

        # log-probability of each class against itself plus all negative clothes classes
        exp_logits = torch.exp(similarities)
        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negative_mask).sum(1, keepdim=True) + exp_logits)
        log_prob = similarities - log_sum_exp_pos_and_all_neg

        # soft target: (1 - epsilon) on the sample's own clothes class plus epsilon spread over positives
        mask = (1 - self.epsilon) * mask_identity + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask
        loss = (- mask * log_prob).sum(1).mean()

        return loss

    def _update_memory(self, features, labels):
        # group features by clothes label; use Python ints as keys, since tensors hash by object
        # identity and tensor keys would not group samples that share a label
        label_to_feat = {}
        for x, y in zip(features, labels.tolist()):
            label_to_feat.setdefault(y, []).append(x.unsqueeze(0))
        for y, feats in label_to_feat.items():
            feat = torch.mean(torch.cat(feats, dim=0), dim=0)
            if not self.has_been_filled:
                # filling phase: store the batch mean for this clothes class
                self.feature_memory[y] = feat
                self.label_memory[y] = y
            else:
                # momentum update once every slot has been filled
                self.feature_memory[y] = self.momentum * self.feature_memory[y] + (1. - self.momentum) * feat
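
To see what the clothes-based adversarial loss does for a single sample, here is a framework-free sketch that mirrors `ClothesBasedAdversarialLoss.forward`. The helper name is hypothetical and the code is not part of the repository. Positive clothes classes (same identity) are left out of each other's denominators. The soft target puts `1 - epsilon` on the sample's own clothes class and spreads `epsilon` evenly over the positives. When an identity owns only one clothes class, the loss should collapse to ordinary scaled cross-entropy.

```python
import math


def cal_loss_single(cosines, target, positive_mask, scale=16.0, epsilon=0.1):
    """Clothes-based adversarial loss for one sample (illustrative sketch).

    cosines: cosine similarities to every clothes class.
    target: index of the sample's own clothes class.
    positive_mask: 1 for clothes classes of the same identity (target included), else 0.
    """
    z = [scale * c for c in cosines]
    # sum of exp over negative clothes classes (other identities) only
    neg_sum = sum(math.exp(zj) for zj, p in zip(z, positive_mask) if p == 0)
    # log-probability of class j against itself plus all negatives
    log_prob = [zj - math.log(neg_sum + math.exp(zj)) for zj in z]
    num_pos = sum(positive_mask)
    loss = 0.0
    for j, (lp, p) in enumerate(zip(log_prob, positive_mask)):
        weight = (1 - epsilon) * (1.0 if j == target else 0.0) + epsilon / num_pos * p
        loss -= weight * lp
    return loss


print(cal_loss_single([0.5, 0.2, -0.1, 0.3], target=0, positive_mask=[1, 1, 0, 0]))
```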

================================================
FILE: losses/contrastive_loss.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from losses.gather import GatherLayer


class ContrastiveLoss(nn.Module):
    """ Supervised Contrastive Learning Loss among sample pairs.

    Args:
        scale (float): scaling factor (inverse temperature) applied to the cosine similarities.
    """
    def __init__(self, scale=16, **kwargs):
        super().__init__()
        self.s = scale

    def forward(self, inputs, targets):
        """
        Args:
            inputs: sample features (before classifier) with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        # l2-normalize
        inputs = F.normalize(inputs, p=2, dim=1)

        # gather all samples from different GPUs as gallery to compute pairwise loss.
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
        m = targets.size(0)  # local batch size

        # compute scaled cosine similarity
        similarities = torch.matmul(inputs, gallery_inputs.t()) * self.s

        # get mask for pos/neg pairs
        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
        mask = torch.eq(targets, gallery_targets.T).float()
        # this process's own samples occupy columns [rank * m, (rank + 1) * m) of the gathered
        # gallery, which assumes every process has the same batch size
        mask_self = torch.zeros_like(mask)
        rank = dist.get_rank()
        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m, device=mask.device)
        mask_pos = mask - mask_self
        mask_neg = 1 - mask

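        # log-probability of each pair against itself plus all negatives (self-pairs excluded), so other
        # positives do not appear in the denominator; the commented-out line sums over all non-self pairs
        exp_logits = torch.exp(similarities) * (1 - mask_self)
        # log_prob = similarities - torch.log(exp_logits.sum(1, keepdim=True))
        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * mask_neg).sum(1, keepdim=True) + exp_logits)
        log_prob = similarities - log_sum_exp_pos_and_all_neg

        # mean log-likelihood over each anchor's positives; the clamp avoids 0/0 for anchors without positives
        loss = (mask_pos * log_prob).sum(1) / mask_pos.sum(1).clamp(min=1)

        loss = - loss.mean()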

        return loss
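
Here is a framework-free sketch of the per-anchor term in `ContrastiveLoss.forward`, written for illustration only. The helper is hypothetical, runs on a single process, and assumes the self-pair has already been removed. As in the module, each positive is scored against itself plus all negatives. The anchor's loss is the negative mean of these log-probabilities over its positives.

```python
import math


def contrastive_loss_single(pos_sims, neg_sims, scale=16.0):
    """Loss of one anchor given cosine similarities to its positives and negatives (self excluded)."""
    neg_sum = sum(math.exp(scale * s) for s in neg_sims)
    # each positive competes only with the negatives, mirroring log_sum_exp_pos_and_all_neg
    log_probs = [scale * s - math.log(math.exp(scale * s) + neg_sum) for s in pos_sims]
    return -sum(log_probs) / len(pos_sims)


print(contrastive_loss_single([0.8, 0.6], [0.1, 0.2, -0.3]))
```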

================================================
FILE: losses/cosface_loss.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from losses.gather import GatherLayer


class CosFaceLoss(nn.Module):
    """ CosFace Loss based on the predictions of classifier.

    Reference:
        Wang et al. CosFace: Large Margin Cosine Loss for Deep Face Recognition. In CVPR, 2018.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=16, margin=0.1, **kwargs):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        one_hot = torch.zeros_like(inputs)
        one_hot.scatter_(1, targets.view(-1, 1), 1.0)

        # subtract the margin from the target-class cosine only, then scale
        output = self.s * (inputs - one_hot * self.m)

        return F.cross_entropy(output, targets)


class PairwiseCosFaceLoss(nn.Module):
    """ CosFace Loss among sample pairs.

    Reference:
        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.

    Args:
        scale (float): scaling factor.
        margin (float): pre-defined margin.
    """
    def __init__(self, scale=16, margin=0):
        super().__init__()
        self.s = scale
        self.m = margin

    def forward(self, inputs, targets):
        """
        Args:
            inputs: sample features (before classifier) with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        # l2-normalize
        inputs = F.normalize(inputs, p=2, dim=1)

        # gather all samples from different GPUs as gallery to compute pairwise loss.
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)
        m = targets.size(0)  # local batch size

        # compute cosine similarity
        similarities = torch.matmul(inputs, gallery_inputs.t())

        # get mask for pos/neg pairs
        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)
        mask = torch.eq(targets, gallery_targets.T).float()
        # this process's own samples occupy columns [rank * m, (rank + 1) * m) of the gathered
        # gallery, which assumes every process has the same batch size
        mask_self = torch.zeros_like(mask)
        rank = dist.get_rank()
        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m, device=mask.device)
        mask_pos = mask - mask_self
        mask_neg = 1 - mask

        # pairwise CosFace: softplus(LSE_n + LSE_p) = log(1 + sum_{p,n} exp(s * (s_n + m - s_p)))
        scores = (similarities + self.m) * mask_neg - similarities * mask_pos
        scores = scores * self.s

        neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1)
        pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1)

        loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean()

        return loss
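
The factored form used by `PairwiseCosFaceLoss`, `softplus(LSE_n + LSE_p)`, is the pairwise CosFace loss `log(1 + sum over (p, n) of exp(s * (s_n - s_p + m)))`. It is rewritten so that the sums over positives and over negatives are computed separately. The hypothetical helpers below compute both forms for one anchor in plain Python, so the identity can be checked numerically. They are an illustration and are not part of the repository.

```python
import math
from itertools import product


def logsumexp(values):
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def softplus(x):
    # numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_cosface_factored(pos_sims, neg_sims, scale=16.0, margin=0.0):
    """Form used by the module: softplus(LSE_n + LSE_p)."""
    neg_lse = logsumexp([scale * (s + margin) for s in neg_sims])
    pos_lse = logsumexp([-scale * s for s in pos_sims])
    return softplus(neg_lse + pos_lse)


def pairwise_cosface_explicit(pos_sims, neg_sims, scale=16.0, margin=0.0):
    """Same loss written as a sum over every (positive, negative) pair."""
    total = sum(math.exp(scale * (sn - sp + margin)) for sp, sn in product(pos_sims, neg_sims))
    return math.log1p(total)


print(pairwise_cosface_factored([0.7, 0.5], [0.2, -0.1, 0.4], margin=0.1))
```

The factored form needs only one pass over positives and one over negatives, instead of one pass over every (positive, negative) pair.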

================================================
FILE: losses/cross_entropy_loss_with_label_smooth.py
================================================
import torch
from torch import nn


class CrossEntropyWithLabelSmooth(nn.Module):
    """ Cross entropy loss with label smoothing regularization.

    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. In CVPR, 2016.
    Equation:
        y = (1 - epsilon) * y + epsilon / K, where K is the number of classes.

    Args:
        epsilon (float): a hyper-parameter in the above equation.
    """
    def __init__(self, epsilon=0.1):
        super().__init__()
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)
        """
        _, num_classes = inputs.size()
        log_probs = self.logsoftmax(inputs)
        # one-hot targets, then smoothed towards the uniform distribution
        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        targets = (1 - self.epsilon) * targets + self.epsilon / num_classes
        loss = (- targets * log_probs).mean(0).sum()

        return loss
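
The smoothed target distribution and the resulting per-sample loss can be sketched in plain Python. The helpers are hypothetical and are not repository code. With `epsilon = 0.1` and `K = 4`, the true class receives `0.9 + 0.025 = 0.925` and every other class receives `0.025`. With `epsilon = 0`, the loss is ordinary cross-entropy.

```python
import math


def smoothed_targets(num_classes, target, epsilon=0.1):
    """(1 - epsilon) * one_hot + epsilon / K, with K = num_classes."""
    return [(1 - epsilon) * (1.0 if j == target else 0.0) + epsilon / num_classes
            for j in range(num_classes)]


def label_smoothed_ce(logits, target, epsilon=0.1):
    """Cross-entropy of one sample against the smoothed target distribution."""
    top = max(logits)
    lse = top + math.log(sum(math.exp(z - top) for z in logits))
    log_probs = [z - lse for z in logits]
    q = smoothed_targets(len(logits), target, epsilon)
    return -sum(qj * lp for qj, lp in zip(q, log_probs))


print(smoothed_targets(4, 2), label_smoothed_ce([5.0, 0.0, 0.0], 0))
```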


================================================
FILE: losses/gather.py
================================================
import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all process, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)

        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)

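        # Return only the gradient that this process's own loss sends to its slice of the gathered
        # tensor. Gradients that other processes compute for this slice are not summed here; the
        # commented-out reduce_scatter would sum them (and then average over processes).
        # dist.reduce_scatter(grad_out, list(grads))
        # grad_out.div_(dist.get_world_size())

        grad_out[:] = grads[dist.get_rank()]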

        return grad_out
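
The effect of returning only `grads[rank]` can be illustrated with a toy, framework-free simulation. The loss, the helper names, and the two-process setup are all assumptions made for this illustration. Each process's loss uses its own feature directly and every process's feature through the gathered list.

- With `GatherLayer.backward` as written, process `r` receives only the gradient that its own loss sends to slot `r`.
- The gradient of the total loss, summed over processes, also includes what the other processes' losses send to that slot.

Whether the difference matters in practice depends on how parameter gradients are combined afterwards (for example by DistributedDataParallel). This toy does not model that step.

```python
def rank_loss(anchor, gathered):
    """Toy per-process loss: local anchor feature times the sum of all gathered features."""
    return anchor * sum(gathered)


def numeric_grad(fn, point, index, h=1e-6):
    """Central finite difference of fn(point) with respect to point[index]."""
    up, down = list(point), list(point)
    up[index] += h
    down[index] -= h
    return (fn(up) - fn(down)) / (2 * h)


def gather_layer_grads(xs):
    """Gradient each process r gets for its own x_r when backward returns grads[rank]."""
    grads = []
    for r, x_r in enumerate(xs):
        # direct path: x_r used as the local anchor (gathered copy held fixed)
        direct = numeric_grad(lambda p: rank_loss(p[0], xs), [x_r], 0)
        # gather path: only this process's own loss, differentiated w.r.t. slot r
        via_gather = numeric_grad(lambda g: rank_loss(x_r, g), xs, r)
        grads.append(direct + via_gather)
    return grads


def full_grads(xs):
    """Gradient of the loss summed over all processes with respect to each x_k."""
    def total(p):
        return sum(rank_loss(p[r], p) for r in range(len(p)))
    return [numeric_grad(total, xs, k) for k in range(len(xs))]


print(gather_layer_grads([1.0, 2.0]), full_grads([1.0, 2.0]))
```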

================================================
FILE: losses/triplet_loss.py
================================================
import torch
import torch.nn.functional as F
from torch import nn
from losses.gather import GatherLayer


class TripletLoss(nn.Module):
    """ Triplet loss with hard example mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Args:
        margin (float): pre-defined margin.

    Note that we use cosine distance (1 - cosine similarity) rather than the Euclidean distance used in the original paper.
    """
    def __init__(self, margin=0.3):
        super().__init__()
        self.m = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: sample features (before classifier) with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        # l2-normalize
        inputs = F.normalize(inputs, p=2, dim=1)

        # gather all samples from different GPUs as gallery to compute pairwise loss.
        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)
        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)

        # compute distance
        dist = 1 - torch.matmul(inputs, gallery_inputs.t()) # values in [0, 2]

        # get positive and negative masks
        targets, gallery_targets = targets.view(-1,1), gallery_targets.view(-1,1)
        mask_pos = torch.eq(targets, gallery_targets.T).float()
        mask_neg = 1 - mask_pos

        # For each anchor, find the hardest positive and negative pairs
        dist_ap, _ = torch.max((dist - mask_neg * 99999999.), dim=1)
        dist_an, _ = torch.min((dist + mask_pos * 99999999.), dim=1)

        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)

        return loss
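
Here is a single-process, framework-free sketch of the batch-hard mining in `TripletLoss.forward`. The helper is hypothetical, and the gallery is just the local batch rather than the gathered one. For each anchor, it takes the largest cosine distance to a same-label sample and the smallest to a different-label sample. It then applies the hinge `max(0, d_ap - d_an + margin)`, which is what `nn.MarginRankingLoss` computes with `y = 1`.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def batch_hard_triplet_loss(features, labels, margin=0.3):
    """Batch-hard triplet loss with cosine distance, on one process (gallery = batch)."""
    normed = [l2_normalize(f) for f in features]
    losses = []
    for fi, yi in zip(normed, labels):
        # cosine distance to every sample in the batch, values in [0, 2]
        dists = [1.0 - sum(a * b for a, b in zip(fi, fj)) for fj in normed]
        hardest_pos = max(d for d, yj in zip(dists, labels) if yj == yi)
        hardest_neg = min(d for d, yj in zip(dists, labels) if yj != yi)
        # same as MarginRankingLoss(dist_an, dist_ap, y=1) for one anchor
        losses.append(max(0.0, hardest_pos - hardest_neg + margin))
    return sum(losses) / len(losses)


print(batch_hard_triplet_loss([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]], [0, 0, 1, 1]))
```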

================================================
FILE: main.py
================================================
import os
import sys
import time
import datetime
import argparse
import logging
import os.path as osp
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import distributed as dist
from apex import amp

from configs.default_img import get_img_config
from configs.default_vid import get_vid_config
from data import build_dataloader
from models import build_model
from losses import build_losses
from tools.utils import save_checkpoint, set_seed, get_logger
from train import train_cal, train_cal_with_memory
from test import test, test_prcc


VID_DATASET = ['ccvid']


def parse_option():
    parser = argparse.ArgumentParser(description='Train clothes-changing re-id model with clothes-based adversarial loss')
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    # Datasets
    parser.add_argument('--root', type=str, help="your root path to data directory")
    parser.add_argument('--dataset', type=str, default='ltcc', help="ltcc, prcc, vcclothes, ccvid, last, deepchange")
    # Miscs
    parser.add_argument('--output', type=str, help="your output path to save model and logs")
    parser.add_argument('--resume', type=str, metavar='PATH')
    parser.add_argument('--amp', action='store_true', help="automatic mixed precision")
    parser.add_argument('--eval', action='store_true', help="evaluation only")
    parser.add_argument('--tag', type=str, help='tag for log file')
    parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')

    args, unparsed = parser.parse_known_args()
    if args.dataset in VID_DATASET:
        config = get_vid_config(args)
    else:
        config = get_img_config(args)

    return config


def main(config):
    # Build dataloader
    if config.DATA.DATASET == 'prcc':
        trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config)
    else:
        trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config)
    # Define a matrix pid2clothes with shape (num_pids, num_clothes). 
    # pid2clothes[i, j] = 1 when j-th clothes belongs to i-th identity. Otherwise, pid2clothes[i, j] = 0.
    pid2clothes = torch.from_numpy(dataset.pid2clothes)

    # Build model
    model, classifier, clothes_classifier = build_model(config, dataset.num_train_pids, dataset.num_train_clothes)
    # Build identity classification loss, pairwise loss, clothes classification loss, and adversarial loss.
    criterion_cla, criterion_pair, criterion_clothes, criterion_adv = build_losses(config, dataset.num_train_clothes)
    # Build optimizer
    parameters = list(model.parameters()) + list(classifier.parameters())
    if config.TRAIN.OPTIMIZER.NAME == 'adam':
        optimizer = optim.Adam(parameters, lr=config.TRAIN.OPTIMIZER.LR,
                               weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
        optimizer_cc = optim.Adam(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR,
                                  weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
    elif config.TRAIN.OPTIMIZER.NAME == 'adamw':
        optimizer = optim.AdamW(parameters, lr=config.TRAIN.OPTIMIZER.LR,
                                weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
        optimizer_cc = optim.AdamW(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR,
                                   weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)
    elif config.TRAIN.OPTIMIZER.NAME == 'sgd':
        optimizer = optim.SGD(parameters, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9,
                              weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True)
        optimizer_cc = optim.SGD(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9,
                                 weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True)
    else:
        raise KeyError("Unknown optimizer: {}".format(config.TRAIN.OPTIMIZER.NAME))
    # Build lr_scheduler
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE, 
                                         gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE)

    start_epoch = config.TRAIN.START_EPOCH
    if config.MODEL.RESUME:
        logger.info("Loading checkpoint from '{}'".format(config.MODEL.RESUME))
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        classifier.load_state_dict(checkpoint['classifier_state_dict'])
        if config.LOSS.CAL == 'calwithmemory':
            criterion_adv.load_state_dict(checkpoint['clothes_classifier_state_dict'])
        else:
            clothes_classifier.load_state_dict(checkpoint['clothes_classifier_state_dict'])
        # checkpoints store the 0-based index of the last completed epoch, so resume from the next one.
        # Optimizer and LR-scheduler states are not saved, so they restart from their initial values.
        start_epoch = checkpoint['epoch'] + 1

    # dist.get_rank() is the global rank; it equals the local GPU index only on a single node
    local_rank = dist.get_rank()
    model = model.cuda(local_rank)
    classifier = classifier.cuda(local_rank)
    if config.LOSS.CAL == 'calwithmemory':
        criterion_adv = criterion_adv.cuda(local_rank)
    else:
        clothes_classifier = clothes_classifier.cuda(local_rank)
    torch.cuda.set_device(local_rank)

    if config.TRAIN.AMP:
        [model, classifier], optimizer = amp.initialize([model, classifier], optimizer, opt_level="O1")
        if config.LOSS.CAL != 'calwithmemory':
            clothes_classifier, optimizer_cc = amp.initialize(clothes_classifier, optimizer_cc, opt_level="O1")

    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    classifier = nn.parallel.DistributedDataParallel(classifier, device_ids=[local_rank], output_device=local_rank)
    if config.LOSS.CAL != 'calwithmemory':
        clothes_classifier = nn.parallel.DistributedDataParallel(clothes_classifier, device_ids=[local_rank], output_device=local_rank)

    if config.EVAL_MODE:
        logger.info("Evaluate only")
        with torch.no_grad():
            if config.DATA.DATASET == 'prcc':
                test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)
            else:
                test(config, model, queryloader, galleryloader, dataset)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    logger.info("==> Start training")
    for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH):
        train_sampler.set_epoch(epoch)
        start_train_time = time.time()
        if config.LOSS.CAL == 'calwithmemory':
            train_cal_with_memory(config, epoch, model, classifier, criterion_cla, criterion_pair, 
                criterion_adv, optimizer, trainloader, pid2clothes)
        else:
            train_cal(config, epoch, model, classifier, clothes_classifier, criterion_cla, criterion_pair, 
                criterion_clothes, criterion_adv, optimizer, optimizer_cc, trainloader, pid2clothes)
        train_time += round(time.time() - start_train_time)

        if ((epoch+1) > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0 and
                (epoch+1) % config.TEST.EVAL_STEP == 0) or (epoch+1) == config.TRAIN.MAX_EPOCH:
            logger.info("==> Test")
            torch.cuda.empty_cache()
            if config.DATA.DATASET == 'prcc':
                rank1 = test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)
            else:
                rank1 = test(config, model, queryloader, galleryloader, dataset)
            torch.cuda.empty_cache()
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            model_state_dict = model.module.state_dict()
            classifier_state_dict = classifier.module.state_dict()
            if config.LOSS.CAL == 'calwithmemory':
                clothes_classifier_state_dict = criterion_adv.state_dict()
            else:
                clothes_classifier_state_dict = clothes_classifier.module.state_dict()
            if local_rank == 0:
                save_checkpoint({
                    'model_state_dict': model_state_dict,
                    'classifier_state_dict': classifier_state_dict,
                    'clothes_classifier_state_dict': clothes_classifier_state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch+1) + '.pth.tar'))
        scheduler.step()

    logger.info("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    logger.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
    

if __name__ == '__main__':
    config = parse_option()
    # Set GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU
    # Init dist
    dist.init_process_group(backend="nccl", init_method='env://')
    local_rank = dist.get_rank()
    # Set random seed
    set_seed(config.SEED + local_rank)
    # get logger
    if not config.EVAL_MODE:
        output_file = osp.join(config.OUTPUT, 'log_train_.log')
    else:
        output_file = osp.join(config.OUTPUT, 'log_test.log')
    logger = get_logger(output_file, local_rank, 'reid')
    logger.info("Config:\n-----------------------------------------")
    logger.info(config)
    logger.info("-----------------------------------------")

    main(config)
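
`pid2clothes` is provided by the dataset classes, which are not part of this excerpt. As an illustration only, the sketch below assumes the training set can be listed as `(pid, clothes_id)` pairs (a hypothetical format) and builds the same 0/1 matrix in plain Python. Selecting the rows for a batch's identities gives a `(batch_size, num_clothes)` matrix that matches the `positive_mask` described in `ClothesBasedAdversarialLoss`. The helper names are hypothetical.

```python
def build_pid2clothes(samples, num_pids, num_clothes):
    """0/1 matrix with entry [pid][clothes_id] = 1 when that clothes class belongs to that identity.

    samples: iterable of (pid, clothes_id) pairs (hypothetical input format).
    """
    matrix = [[0] * num_clothes for _ in range(num_pids)]
    for pid, clothes_id in samples:
        matrix[pid][clothes_id] = 1
    return matrix


def batch_positive_mask(pid2clothes, batch_pids):
    """Rows of pid2clothes for the identities in a batch: shape (batch_size, num_clothes)."""
    return [list(pid2clothes[pid]) for pid in batch_pids]


pid2clothes = build_pid2clothes([(0, 0), (0, 1), (1, 2)], num_pids=2, num_clothes=3)
print(batch_positive_mask(pid2clothes, [1, 0, 0]))
```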

================================================
FILE: models/__init__.py
================================================
import logging
from models.classifier import Classifier, NormalizedClassifier
from models.img_resnet import ResNet50
from models.vid_resnet import C2DResNet50, I3DResNet50, AP3DResNet50, NLResNet50, AP3DNLResNet50


__factory = {
    'resnet50': ResNet50,
    'c2dres50': C2DResNet50,
    'i3dres50': I3DResNet50,
    'ap3dres50': AP3DResNet50,
    'nlres50': NLResNet50,
    'ap3dnlres50': AP3DNLResNet50,
}


def build_model(config, num_identities, num_clothes):
    logger = logging.getLogger('reid.model')
    # Build backbone
    logger.info("Initializing model: {}".format(config.MODEL.NAME))
    if config.MODEL.NAME not in __factory.keys():
        raise KeyError("Invalid model: '{}'".format(config.MODEL.NAME))
    else:
        logger.info("Init model: '{}'".format(config.MODEL.NAME))
        model = __factory[config.MODEL.NAME](config)
    logger.info("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0))

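    # Build identity classifier: a plain linear layer for (label-smoothed) cross-entropy,
    # a cosine (normalized) classifier otherwise; the clothes classifier is always normalized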
    if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']:
        identity_classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities)
    else:
        identity_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities)

    clothes_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes)

    return model, identity_classifier, clothes_classifier

================================================
FILE: models/classifier.py
================================================
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from torch.nn import Parameter


__all__ = ['Classifier', 'NormalizedClassifier']


class Classifier(nn.Module):
    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(feature_dim, num_classes)
        init.normal_(self.classifier.weight.data, std=0.001)
        init.constant_(self.classifier.bias.data, 0.0)

    def forward(self, x):
        y = self.classifier(x)

        return y
        

class NormalizedClassifier(nn.Module):
    def __init__(self, feature_dim, num_classes):
        super().__init__()
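        self.weight = Parameter(torch.Tensor(num_classes, feature_dim))
        # uniform init, then rescale every row (class weight) to unit L2 norm
        self.weight.data.uniform_(-1, 1).renorm_(2, 0, 1e-5).mul_(1e5)

    def forward(self, x):
        # cosine similarity between each feature and each class weight, values in [-1, 1]
        x = F.normalize(x, p=2, dim=1)
        w = F.normalize(self.weight, p=2, dim=1)

        return F.linear(x, w)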
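
`NormalizedClassifier` returns cosine similarities. Its outputs therefore lie in [-1, 1] and do not change when a feature is rescaled, which is why the margin-based losses in `losses/` treat their inputs as cosines. Here is a plain-Python sketch of the same computation for one feature at a time. The helper names are hypothetical.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def normalized_linear(x, weight_rows):
    """Cosine similarity between one feature vector and each class weight (one row per class)."""
    xn = l2_normalize(x)
    return [sum(a * b for a, b in zip(xn, l2_normalize(w))) for w in weight_rows]


weights = [[1.0, 0.0], [0.0, 3.0], [-2.0, -2.0]]
print(normalized_linear([2.0, 0.0], weights))
```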





================================================
FILE: models/img_resnet.py
================================================
import torchvision
from torch import nn
from torch.nn import init
from models.utils import pooling
        

class ResNet50(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__()

        resnet50 = torchvision.models.resnet50(pretrained=True)
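        if config.MODEL.RES4_STRIDE == 1:
            # remove the downsampling in the last stage (conv2 and the shortcut) to double the
            # spatial resolution of the output feature map
            resnet50.layer4[0].conv2.stride = (1, 1)
            resnet50.layer4[0].downsample[0].stride = (1, 1)
        # keep everything up to layer4; drop the original global average pooling and fc layers
        self.base = nn.Sequential(*list(resnet50.children())[:-2])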

        if config.MODEL.POOLING.NAME == 'avg':
            self.globalpooling = nn.AdaptiveAvgPool2d(1)
        elif config.MODEL.POOLING.NAME == 'max':
            self.globalpooling = nn.AdaptiveMaxPool2d(1)
        elif config.MODEL.POOLING.NAME == 'gem':
            self.globalpooling = pooling.GeMPooling(p=config.MODEL.POOLING.P)
        elif config.MODEL.POOLING.NAME == 'maxavg':
            self.globalpooling = pooling.MaxAvgPooling()
        else:
            raise KeyError("Invalid pooling: '{}'".format(config.MODEL.POOLING.NAME))

        self.bn = nn.BatchNorm1d(config.MODEL.FEATURE_DIM)
        init.normal_(self.bn.weight.data, 1.0, 0.02)
        init.constant_(self.bn.bias.data, 0.0)
        
    def forward(self, x):
        x = self.base(x)
        x = self.globalpooling(x)
        x = x.view(x.size(0), -1)
        f = self.bn(x)

        return f
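
The effect of `RES4_STRIDE` on the feature-map size can be worked out with a small helper. It assumes the torchvision ResNet-50 layout: a stride-2 stem convolution and max-pool, then layer1 to layer4 with strides 1, 2, 2, 2. For the strides used here (1 or 2), the padding of these layers means each stage maps a side of length n to ceil(n / stride). The helper name and the input sizes in the usage line are example values, not taken from the repository's configs.

```python
def resnet50_feature_size(height, width, res4_stride=2):
    """Spatial size of the layer4 output of a torchvision-style ResNet-50."""
    # stem conv1, stem max-pool, layer1, layer2, layer3, layer4
    strides = [2, 2, 1, 2, 2, res4_stride]
    h, w = height, width
    for s in strides:
        # with the padding these layers use, a stride-s stage maps n to ceil(n / s)
        h, w = -(-h // s), -(-w // s)
    return h, w


print(resnet50_feature_size(384, 192), resnet50_feature_size(384, 192, res4_stride=1))
```

Setting the last stride to 1 doubles both sides of the map passed to global pooling, at the cost of more computation in layer4.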

================================================
FILE: models/utils/c3d_blocks.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class APM(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim=3, temperature=4, contrastive_att=True):
        super(APM, self).__init__()

        self.time_dim = time_dim 
        self.temperature = temperature
        self.contrastive_att = contrastive_att

        padding = (0, 0, 0, 0, (time_dim-1)//2, (time_dim-1)//2)
        self.padding = nn.ConstantPad3d(padding, value=0)

        self.semantic_mapping = nn.Conv3d(in_channels, out_channels, \
                                          kernel_size=1, bias=False)          
        if self.contrastive_att:  
            self.x_mapping = nn.Conv3d(in_channels, out_channels, \
                                       kernel_size=1, bias=False)
            self.n_mapping = nn.Conv3d(in_channels, out_channels, \
                                       kernel_size=1, bias=False)
            self.contrastive_att_net = nn.Sequential(nn.Conv3d(out_channels, 1, \
                                kernel_size=1, bias=False), nn.Sigmoid())

    def forward(self, x):
        b, c, t, h, w = x.size()
        N = self.time_dim

        neighbor_time_index = torch.cat([(torch.arange(0,t)+i).unsqueeze(0) for i in range(N) if i!=N//2], dim=0).t().flatten().long()

        # feature map registration
        semantic = self.semantic_mapping(x) # (b, c/16, t, h, w)
        x_norm = F.normalize(semantic, p=2, dim=1) # (b, c/16, t, h, w)
        x_norm_padding = self.padding(x_norm) # (b, c/16, t+2, h, w)
        x_norm_expand = x_norm.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).permute(0, 2, 3, 4, 5, 1).contiguous().view(-1, h*w, c//16) # (b*t*2, h*w, c/16) 
        neighbor_norm = x_norm_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 1, 3, 4).contiguous().view(-1, c//16, h*w) # (b*t*2, c/16, h*w) 

        similarity = torch.matmul(x_norm_expand, neighbor_norm) * self.temperature # (b*t*2, h*w, h*w)
        similarity = F.softmax(similarity, dim=-1) # (b*t*2, h*w, h*w)

        x_padding = self.padding(x)
        neighbor = x_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 3, 4, 1).contiguous().view(-1, h*w, c)
        neighbor_new = torch.matmul(similarity, neighbor).view(b, t*(N-1), h, w, c).permute(0, 4, 1, 2, 3) # (b, c, t*2, h, w)

        # contrastive attention
        if self.contrastive_att:
            x_att = self.x_mapping(x.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).contiguous().view(b, c, (N-1)*t, h, w).detach())
            n_att = self.n_mapping(neighbor_new.detach())
            contrastive_att = self.contrastive_att_net(x_att * n_att)    
            neighbor_new = neighbor_new * contrastive_att

        # integrating feature maps
        x_offset = torch.zeros([b, c, N*t, h, w], dtype=x.dtype, device=x.device)
        x_index = torch.tensor([i for i in range(t*N) if i%N==N//2])
        neighbor_index = torch.tensor([i for i in range(t*N) if i%N!=N//2])
        x_offset[:, :, x_index, :, :] += x
        x_offset[:, :, neighbor_index, :, :] += neighbor_new

        return x_offset
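

# Illustrative sketch (hypothetical helper, pure Python) of the frame bookkeeping in
# APM.forward, assuming an odd time_dim N. neighbor_time_index addresses the time axis
# after padding it with (N-1)//2 zero frames on each side, and x_offset interleaves each
# frame with its N-1 aligned neighbours, so for N == 3 the following Conv3d (kernel N,
# temporal stride N) sees [previous, current, next] for every frame.
def _apm_indices(t, N=3):
    # Same order as torch.cat([...], dim=0).t().flatten(): frame-major, neighbour-minor.
    neighbor_time_index = [j + i for j in range(t) for i in range(N) if i != N // 2]
    x_index = [i for i in range(t * N) if i % N == N // 2]
    neighbor_index = [i for i in range(t * N) if i % N != N // 2]
    return neighbor_time_index, x_index, neighbor_index

# With t = 4, padded index 0 is the zero frame before frame 0 and index 5 the zero
# frame after frame 3.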


class C2D(nn.Module):
    def __init__(self, conv2d, **kwargs):
        super(C2D, self).__init__()

        # conv3d kernel
        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (1, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                kernel_size=kernel_dim, padding=padding, \
                                stride=stride, bias=conv2d.bias is not None)

        # init the parameters of conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2)
        weight_3d[:, :, 0, :, :] = weight_2d
        self.conv3d.weight = nn.Parameter(weight_3d)
        self.conv3d.bias = conv2d.bias

    def forward(self, x):
        out = self.conv3d(x)

        return out


class I3D(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
        super(I3D, self).__init__()

        # conv3d kernel
        kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
        padding = (time_dim//2, conv2d.padding[0], conv2d.padding[1])
        self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                kernel_size=kernel_dim, padding=padding, \
                                stride=stride, bias=conv2d.bias is not None)

        # init the parameters of conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
        self.conv3d.weight = nn.Parameter(weight_3d)
        self.conv3d.bias = conv2d.bias

    def forward(self, x):
        out = self.conv3d(x)

        return out
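

# Illustrative sketch (hypothetical helpers, pure Python) of the "center" initialization
# used by I3D: the 2D weights go into the middle temporal slice and the other slices are
# zero, so with temporal padding time_dim//2 the inflated conv initially reproduces the
# 2D conv frame by frame. One scalar per frame stands in for the spatial conv response.
def _temporal_conv1d(frames, kernel, pad):
    # Cross-correlation with zero padding, as nn.Conv3d does along time.
    padded = [0.0] * pad + list(frames) + [0.0] * pad
    k = len(kernel)
    return [sum(kernel[i] * padded[s + i] for i in range(k))
            for s in range(len(padded) - k + 1)]


def _center_kernel(weight_2d, time_dim=3):
    kernel = [0.0] * time_dim
    kernel[time_dim // 2] = weight_2d
    return kernel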


class API3D(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
        super(API3D, self).__init__()

        self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \
                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)
        
        # conv3d kernel
        kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                kernel_size=kernel_dim, padding=padding, \
                                stride=stride, bias=conv2d.bias is not None)

        # init the parameters of conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
        self.conv3d.weight = nn.Parameter(weight_3d)
        self.conv3d.bias = conv2d.bias

    def forward(self, x):
        x_offset = self.APM(x)
        out = self.conv3d(x_offset)

        return out
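

# Illustrative sketch (hypothetical helper, pure Python): why API3D uses temporal stride
# time_stride*time_dim and no temporal padding. APM expands t frames to time_dim*t, and a
# kernel of size time_dim stepping by time_dim visits each [prev, current, next] group
# exactly once, so with time_stride=1 the clip length t is restored.
def _api3d_out_frames(t, time_dim=3, time_stride=1):
    expanded = time_dim * t  # temporal length of APM's x_offset
    return (expanded - time_dim) // (time_stride * time_dim) + 1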


class P3DA(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
        super(P3DA, self).__init__()

        # spatial conv3d kernel
        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (1, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                        kernel_size=kernel_dim, padding=padding, \
                                        stride=stride, bias=conv2d.bias is not None)

        # init the parameters of spatial_conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2)
        weight_3d[:, :, 0, :, :] = weight_2d
        self.spatial_conv3d.weight = nn.Parameter(weight_3d)
        self.spatial_conv3d.bias = conv2d.bias


        # temporal conv3d kernel
        kernel_dim = (time_dim, 1, 1)
        stride = (time_stride, 1, 1)
        padding = (time_dim//2, 0, 0)
        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
                                         kernel_size=kernel_dim, padding=padding, \
                                         stride=stride, bias=False)

        # init the parameters of temporal_conv3d
        weight_2d = torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2)
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
        self.temporal_conv3d.weight = nn.Parameter(weight_3d)


    def forward(self, x):
        x = self.spatial_conv3d(x)
        out = self.temporal_conv3d(x)

        return out
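

# Illustrative sketch (hypothetical helper, pure Python): weight counts for a full
# time_dim x k x k kernel (as in I3D) versus the factorized spatial (1 x k x k) plus
# temporal (time_dim x 1 x 1) pair used by P3DA, where the temporal conv maps
# out_channels to out_channels. Biases are ignored.
def _conv_weight_counts(c_in, c_out, k=3, time_dim=3):
    full_3d = c_out * c_in * time_dim * k * k
    factorized = c_out * c_in * k * k + c_out * c_out * time_dim
    return full_3d, factorized

# For a 64 -> 64 channel 3x3 conv with time_dim=3 this is 110592 vs 49152 weights.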


class P3DB(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
        super(P3DB, self).__init__()

        # spatial conv3d kernel
        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (1, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                        kernel_size=kernel_dim, padding=padding, \
                                        stride=stride, bias=conv2d.bias is not None)

        # init the parameters of spatial_conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2)
        weight_3d[:, :, 0, :, :] = weight_2d
        self.spatial_conv3d.weight = nn.Parameter(weight_3d)
        self.spatial_conv3d.bias = conv2d.bias


        # temporal conv3d kernel
        kernel_dim = (time_dim, 1, 1)
        stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
        padding = (time_dim//2, 0, 0)
        self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                         kernel_size=kernel_dim, padding=padding, \
                                         stride=stride, bias=False)

        # init the parameters of temporal_conv3d
        nn.init.constant_(self.temporal_conv3d.weight, 0)


    def forward(self, x):
        out1 = self.spatial_conv3d(x)
        out2 = self.temporal_conv3d(x)
        out = out1 + out2

        return out


class P3DC(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
        super(P3DC, self).__init__()

        # spatial conv3d kernel
        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (1, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                        kernel_size=kernel_dim, padding=padding, \
                                        stride=stride, bias=conv2d.bias is not None)

        # init the parameters of spatial_conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2)
        weight_3d[:, :, 0, :, :] = weight_2d
        self.spatial_conv3d.weight = nn.Parameter(weight_3d)
        self.spatial_conv3d.bias = conv2d.bias


        # temporal conv3d kernel
        kernel_dim = (time_dim, 1, 1)
        stride = (time_stride, 1, 1)
        padding = (time_dim//2, 0, 0)
        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
                                         kernel_size=kernel_dim, padding=padding, \
                                         stride=stride, bias=False)

        # init the parameters of temporal_conv3d
        nn.init.constant_(self.temporal_conv3d.weight, 0)


    def forward(self, x):
        out = self.spatial_conv3d(x)
        residual = self.temporal_conv3d(out)
        out = out + residual

        return out
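

# Illustrative sketch (hypothetical helpers, pure Python) of how P3DA, P3DB and P3DC wire
# the spatial (S) and temporal (T) convs, with one scalar per frame and an odd-length
# temporal kernel. With the initializations above (identity for P3DA's T, zeros for
# P3DB's and P3DC's T), all three start out equal to the spatial conv alone, i.e. to the
# pretrained 2D network applied frame by frame.
def _p3d_temporal(frames, kernel):
    pad = len(kernel) // 2
    padded = [0.0] * pad + list(frames) + [0.0] * pad
    return [sum(k * padded[s + i] for i, k in enumerate(kernel))
            for s in range(len(frames))]


def _p3d(frames, variant, spatial_weight, temporal_kernel):
    spatial = [spatial_weight * f for f in frames]
    if variant == 'A':  # serial: T(S(x))
        return _p3d_temporal(spatial, temporal_kernel)
    if variant == 'B':  # parallel: S(x) + T(x)
        return [a + b for a, b in zip(spatial, _p3d_temporal(frames, temporal_kernel))]
    if variant == 'C':  # serial with residual: S(x) + T(S(x))
        return [a + b for a, b in zip(spatial, _p3d_temporal(spatial, temporal_kernel))]
    raise ValueError("Unknown P3D variant: '{}'".format(variant))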


class APP3DA(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
        super(APP3DA, self).__init__()

        self.APM = APM(conv2d.out_channels, conv2d.out_channels//16, \
                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)

        # spatial conv3d kernel
        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (1, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                        kernel_size=kernel_dim, padding=padding, \
                                        stride=stride, bias=conv2d.bias is not None)

        # init the parameters of spatial_conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2)
        weight_3d[:, :, 0, :, :] = weight_2d
        self.spatial_conv3d.weight = nn.Parameter(weight_3d)
        self.spatial_conv3d.bias = conv2d.bias


        # temporal conv3d kernel
        kernel_dim = (time_dim, 1, 1)
        stride = (time_stride*time_dim, 1, 1)
        padding = (0, 0, 0)
        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
                                         kernel_size=kernel_dim, padding=padding, \
                                         stride=stride, bias=False)

        # init the parameters of temporal_conv3d
        weight_2d = torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2)
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
        self.temporal_conv3d.weight = nn.Parameter(weight_3d)


    def forward(self, x):
        x = self.spatial_conv3d(x)
        out = self.temporal_conv3d(self.APM(x))

        return out


class APP3DB(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
        super(APP3DB, self).__init__()

        self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \
                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)

        # spatial conv3d kernel
        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (1, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                        kernel_size=kernel_dim, padding=padding, \
                                        stride=stride, bias=conv2d.bias is not None)

        # init the parameters of spatial_conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2)
        weight_3d[:, :, 0, :, :] = weight_2d
        self.spatial_conv3d.weight = nn.Parameter(weight_3d)
        self.spatial_conv3d.bias = conv2d.bias


        # temporal conv3d kernel
        kernel_dim = (time_dim, 1, 1)
        stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[1])
        padding = (0, 0, 0)
        self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                         kernel_size=kernel_dim, padding=padding, \
                                         stride=stride, bias=False)

        # init the parameters of temporal_conv3d
        nn.init.constant_(self.temporal_conv3d.weight, 0)


    def forward(self, x):
        out1 = self.spatial_conv3d(x)
        out2 = self.temporal_conv3d(self.APM(x))
        out = out1 + out2

        return out


class APP3DC(nn.Module):
    def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):
        super(APP3DC, self).__init__()

        self.APM = APM(conv2d.out_channels, conv2d.out_channels//16, \
                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)

        # spatial conv3d kernel
        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])
        stride = (1, conv2d.stride[0], conv2d.stride[1])
        padding = (0, conv2d.padding[0], conv2d.padding[1])
        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \
                                        kernel_size=kernel_dim, padding=padding, \
                                        stride=stride, bias=conv2d.bias is not None)

        # init the parameters of spatial_conv3d
        weight_2d = conv2d.weight.data
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2)
        weight_3d[:, :, 0, :, :] = weight_2d
        self.spatial_conv3d.weight = nn.Parameter(weight_3d)
        self.spatial_conv3d.bias = conv2d.bias


        # temporal conv3d kernel
        kernel_dim = (time_dim, 1, 1)
        stride = (time_stride*time_dim, 1, 1)
        padding = (0, 0, 0)
        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \
                                         kernel_size=kernel_dim, padding=padding, \
                                         stride=stride, bias=False)

        # init the parameters of temporal_conv3d
        nn.init.constant_(self.temporal_conv3d.weight, 0)


    def forward(self, x):
        out = self.spatial_conv3d(x)
        residual = self.temporal_conv3d(self.APM(out))
        out = out + residual

        return out


================================================
FILE: models/utils/inflate.py
================================================
# inflate 2D modules to 3D modules
import torch
import torch.nn as nn
from torch.nn import functional as F


def inflate_conv(conv2d,
                 time_dim=1,
                 time_padding=0,
                 time_stride=1,
                 time_dilation=1,
                 center=False):
    # To preserve activations, temporal padding should replicate the border frames
    # rather than insert zeros; otherwise, use no padding in the time dimension.
    kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])
    padding = (time_padding, conv2d.padding[0], conv2d.padding[1])
    stride = (time_stride, conv2d.stride[0], conv2d.stride[1])
    dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1])
    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_dim,
        padding=padding,
        dilation=dilation,
        stride=stride)
    # Repeat filter time_dim times along time dimension
    weight_2d = conv2d.weight.data
    if center:
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
    else:
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        weight_3d = weight_3d / time_dim

    # Assign new params
    conv3d.weight = nn.Parameter(weight_3d)
    conv3d.bias = conv2d.bias
    return conv3d
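

# Illustrative sketch (hypothetical helper, pure Python) contrasting the two
# initializations in inflate_conv, with one scalar per frame standing in for the 2D
# filter response and no temporal padding (time_padding=0). Both keep the 2D response on
# a clip whose frames are identical; they differ once the frames change, where "center"
# stays frame-wise and the default averages over the temporal window.
def _inflated_response(frames, weight_2d, time_dim=3, center=False):
    if center:
        kernel = [0.0] * time_dim
        kernel[time_dim // 2] = weight_2d
    else:
        kernel = [weight_2d / time_dim] * time_dim  # repeat along time, then divide
    return [sum(k * frames[s + i] for i, k in enumerate(kernel))
            for s in range(len(frames) - time_dim + 1)]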


def inflate_linear(linear2d, time_dim):
    """
    Args:
        time_dim: final time dimension of the features
    """
    linear3d = nn.Linear(linear2d.in_features * time_dim,
                               linear2d.out_features)
    weight3d = linear2d.weight.data.repeat(1, time_dim)
    weight3d = weight3d / time_dim

    linear3d.weight = nn.Parameter(weight3d)
    linear3d.bias = linear2d.bias
    return linear3d
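

# Illustrative sketch (hypothetical helpers, pure Python): inflate_linear tiles the 2D
# weight matrix time_dim times along the input axis and divides by time_dim, so applying
# it to the frame-concatenated features equals averaging the 2D layer's outputs over the
# frames. The bias is shared, so it is added once in both cases.
def _linear(weight, bias, x):
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def _inflate_weight(weight, time_dim):
    # List-of-rows equivalent of weight.repeat(1, time_dim) / time_dim.
    return [[w / time_dim for w in row * time_dim] for row in weight]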


def inflate_batch_norm(batch2d):
    # In pytorch 0.2.0 the 2d and 3d versions of batch norm
    # work identically except for the check that verifies the
    # input dimensions

    batch3d = nn.BatchNorm3d(batch2d.num_features)
    # retrieve 3d _check_input_dim function
    batch2d._check_input_dim = batch3d._check_input_dim
    return batch2d


def inflate_pool(pool2d,
                 time_dim=1,
                 time_padding=0,
                 time_stride=None,
                 time_dilation=1):
    kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size)
    padding = (time_padding, pool2d.padding, pool2d.padding)
    if time_stride is None:
        time_stride = time_dim
    stride = (time_stride, pool2d.stride, pool2d.stride)
    if isinstance(pool2d, nn.MaxPool2d):
        dilation = (time_dilation, pool2d.dilation, pool2d.dilation)
        pool3d = nn.MaxPool3d(
            kernel_dim,
            padding=padding,
            dilation=dilation,
            stride=stride,
            ceil_mode=pool2d.ceil_mode)
    elif isinstance(pool2d, nn.AvgPool2d):
        pool3d = nn.AvgPool3d(kernel_dim, stride=stride)
    else:
        raise ValueError(
            '{} is not among known pooling classes'.format(type(pool2d)))
    return pool3d


class MaxPool2dFor3dInput(nn.Module):
    """
    Since nn.MaxPool3d is nondeterministic operation, using fixed random seeds can't get consistent results.
    So we attempt to use max_pool2d to implement MaxPool3d with kernelsize (1, kernel_size, kernel_size).
    """
    def __init__(self, kernel_size, stride=None, padding=0, dilation=1):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)
    def forward(self, x):
        b, c, t, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous() # b, t, c, h, w
        x = x.view(b*t, c, h, w)
        # max pooling
        x = self.maxpool(x)
        _, _, h, w = x.size()
        x = x.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous()

        return x

================================================
FILE: models/utils/nonlocal_blocks.py
================================================
import torch
import math
from torch import nn
from torch.nn import functional as F
from models.utils import inflate


class NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            # max_pool = inflate.MaxPool2dFor3dInput
            max_pool = nn.MaxPool3d
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool = nn.MaxPool2d
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool = nn.MaxPool1d
            bn = nn.BatchNorm1d

        self.g = conv_nd(self.in_channels, self.inter_channels,
                         kernel_size=1, stride=1, padding=0, bias=True)
        self.theta = conv_nd(self.in_channels, self.inter_channels,
                             kernel_size=1, stride=1, padding=0, bias=True)
        self.phi = conv_nd(self.in_channels, self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        # if sub_sample:
        #     self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
        #     self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))
        if sub_sample:
            if dimension == 3:
                self.g = nn.Sequential(self.g, max_pool((1, 2, 2)))
                self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2)))
            else:
                self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(self.inter_channels, self.in_channels,
                        kernel_size=1, stride=1, padding=0, bias=True),
                bn(self.in_channels)
            )
        else:
            self.W = conv_nd(self.inter_channels, self.in_channels,
                             kernel_size=1, stride=1, padding=0, bias=True)
        
        # init
        for m in self.modules():
            if isinstance(m, conv_nd):
                # Product over all kernel dims, so Conv1d (1-element kernel_size) also works.
                n = m.out_channels
                for k in m.kernel_size:
                    n *= k
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, bn):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if bn_layer:
            nn.init.constant_(self.W[1].weight.data, 0.0)
            nn.init.constant_(self.W[1].bias.data, 0.0)
        else:
            nn.init.constant_(self.W.weight.data, 0.0)
            nn.init.constant_(self.W.bias.data, 0.0)


    def forward(self, x):
        '''
        :param x: (b, c, t, h, w) when dimension=3; (b, c, h, w) or (b, c, l) for 2D/1D
        :return: z = W(y) + x, with the same shape as x
        '''
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f = F.softmax(f, dim=-1)

        y = torch.matmul(f, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        y = self.W(y)
        z = y + x

        return z
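

# Illustrative sketch (hypothetical helper, pure Python) of the embedded-Gaussian
# non-local operation in forward(): f_ij = softmax_j(theta(x_i) * phi(x_j)),
# y_i = sum_j f_ij * g(x_j), z_i = W(y_i) + x_i. As an assumption for brevity, the 1x1
# convs are replaced by scalar gains on one-channel features and sub-sampling is omitted.
# Because the last BatchNorm (or W itself) is zero-initialized, W(y) starts at 0 and the
# block begins as an identity mapping.
import math


def _non_local_1ch(x, theta=1.0, phi=1.0, g=1.0, w=0.0):
    out = []
    for xi in x:
        scores = [theta * xi * phi * xj for xj in x]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        y = sum(e / total * g * xj for e, xj in zip(exps, x))
        out.append(w * y + xi)
    return out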


class NonLocalBlock1D(NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NonLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NonLocalBlock2D(NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NonLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NonLocalBlock3D(NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NonLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


================================================
FILE: models/utils/pooling.py
================================================
import torch
from torch import nn
from torch.nn import functional as F


class GeMPooling(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), x.size()[2:]).pow(1./self.p)
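

# Illustrative sketch (hypothetical helper, pure Python) of what GeMPooling computes for
# one channel: (mean(clamp(x, eps) ** p)) ** (1 / p). p = 1 gives average pooling and the
# result approaches max pooling as p grows; in the module above, p is a learnable
# nn.Parameter initialized from the config value.
def _gem(values, p=3.0, eps=1e-6):
    clamped = [max(v, eps) for v in values]
    return (sum(v ** p for v in clamped) / len(clamped)) ** (1.0 / p)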


class MaxAvgPooling(nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpooling = nn.AdaptiveMaxPool2d(1)
        self.avgpooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        max_f = self.maxpooling(x)
        avg_f = self.avgpooling(x)

        return torch.cat((max_f, avg_f), 1)
        

================================================
FILE: models/vid_resnet.py
================================================
import torchvision
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
from models.utils import inflate
from models.utils import c3d_blocks
from models.utils import nonlocal_blocks


__all__ = ['AP3DResNet50', 'AP3DNLResNet50', 'NLResNet50', 'C2DResNet50', 
           'I3DResNet50', 
          ] 


class Bottleneck3D(nn.Module):
    def __init__(self, bottleneck2d, block, inflate_time=False, temperature=4, contrastive_att=True):
        super().__init__()
        self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1)
        self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1)
        if inflate_time:
            self.conv2 = block(bottleneck2d.conv2, temperature=temperature, contrastive_att=contrastive_att)
        else:
            self.conv2 = inflate.inflate_conv(bottleneck2d.conv2, time_dim=1)
        self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2)
        self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1)
        self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3)
        self.relu = nn.ReLU(inplace=True)

        if bottleneck2d.downsample is not None:
            self.downsample = self._inflate_downsample(bottleneck2d.downsample)
        else:
            self.downsample = None

    def _inflate_downsample(self, downsample2d, time_stride=1):
        downsample3d = nn.Sequential(
            inflate.inflate_conv(downsample2d[0], time_dim=1, 
                                 time_stride=time_stride),
            inflate.inflate_batch_norm(downsample2d[1]))
        return downsample3d

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet503D(nn.Module):
    def __init__(self, config, block, c3d_idx, nl_idx, **kwargs):
        super().__init__()
        self.block = block
        self.temperature = config.MODEL.AP3D.TEMPERATURE
        self.contrastive_att = config.MODEL.AP3D.CONTRACTIVE_ATT

        resnet2d = torchvision.models.resnet50(pretrained=True)
        if config.MODEL.RES4_STRIDE == 1:
            resnet2d.layer4[0].conv2.stride=(1, 1)
            resnet2d.layer4[0].downsample[0].stride=(1, 1) 

        self.conv1 = inflate.inflate_conv(resnet2d.conv1, time_dim=1)
        self.bn1 = inflate.inflate_batch_norm(resnet2d.bn1)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = inflate.inflate_pool(resnet2d.maxpool, time_dim=1)
        # self.maxpool = inflate.MaxPool2dFor3dInput(kernel_size=resnet2d.maxpool.kernel_size,
        #                                            stride=resnet2d.maxpool.stride,
        #                                            padding=resnet2d.maxpool.padding,
        #                                            dilation=resnet2d.maxpool.dilation)

        self.layer1 = self._inflate_reslayer(resnet2d.layer1, c3d_idx=c3d_idx[0], \
                                             nonlocal_idx=nl_idx[0], nonlocal_channels=256)
        self.layer2 = self._inflate_reslayer(resnet2d.layer2, c3d_idx=c3d_idx[1], \
                                             nonlocal_idx=nl_idx[1], nonlocal_channels=512)
        self.layer3 = self._inflate_reslayer(resnet2d.layer3, c3d_idx=c3d_idx[2], \
                                             nonlocal_idx=nl_idx[2], nonlocal_channels=1024)
        self.layer4 = self._inflate_reslayer(resnet2d.layer4, c3d_idx=c3d_idx[3], \
                                             nonlocal_idx=nl_idx[3], nonlocal_channels=2048)

        self.bn = nn.BatchNorm1d(2048)
        init.normal_(self.bn.weight.data, 1.0, 0.02)
        init.constant_(self.bn.bias.data, 0.0)

    def _inflate_reslayer(self, reslayer2d, c3d_idx, nonlocal_idx=[], nonlocal_channels=0):
        reslayers3d = []
        for i,layer2d in enumerate(reslayer2d):
            if i not in c3d_idx:
                layer3d = Bottleneck3D(layer2d, c3d_blocks.C2D, inflate_time=False)
            else:
                layer3d = Bottleneck3D(layer2d, self.block, inflate_time=True, \
                                       temperature=self.temperature, contrastive_att=self.contrastive_att)
            reslayers3d.append(layer3d)

            if i in nonlocal_idx:
                non_local_block = nonlocal_blocks.NonLocalBlock3D(nonlocal_channels, sub_sample=True)
                reslayers3d.append(non_local_block)

        return nn.Sequential(*reslayers3d)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        b, c, t, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(b*t, c, h, w)
        # spatial max pooling
        x = F.max_pool2d(x, x.size()[2:])
        x = x.view(b, t, -1)
        # temporal avg pooling
        x = x.mean(1)
        f = self.bn(x)

        return f


def C2DResNet50(config, **kwargs):
    c3d_idx = [[],[],[],[]]
    nl_idx = [[],[],[],[]]

    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)


def AP3DResNet50(config, **kwargs):
    c3d_idx = [[],[0, 2],[0, 2, 4],[]]
    nl_idx = [[],[],[],[]]

    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)


def I3DResNet50(config, **kwargs):
    c3d_idx = [[],[0, 2],[0, 2, 4],[]]
    nl_idx = [[],[],[],[]]

    return ResNet503D(config, c3d_blocks.I3D, c3d_idx, nl_idx, **kwargs)


def AP3DNLResNet50(config, **kwargs):
    c3d_idx = [[],[0, 2],[0, 2, 4],[]]
    nl_idx = [[],[1, 3],[1, 3, 5],[]]

    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)


def NLResNet50(config, **kwargs):
    c3d_idx = [[],[],[],[]]
    nl_idx = [[],[1, 3],[1, 3, 5],[]]

    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)

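In the factory functions above, `c3d_idx` selects, for each ResNet stage, which bottleneck blocks `_inflate_reslayer` inflates with the temporal block type. `nl_idx` selects after which blocks a non-local block is appended. The pure-Python sketch below mirrors that loop so the resulting layout can be inspected without building any modules. The stage depths (3, 4, 6, 3) are the standard ResNet-50 values, and the helper name `layer_layout` is hypothetical.

```python
# Hypothetical helper mirroring the loop in ResNet503D._inflate_reslayer.
# It only records block types as tags; it does not build any modules.
RESNET50_DEPTHS = (3, 4, 6, 3)  # bottleneck blocks per stage in standard ResNet-50


def layer_layout(c3d_idx, nl_idx, depths=RESNET50_DEPTHS):
    """Return, per stage, a list of 'C2D', 'C3D' and 'NL' tags in module order."""
    layout = []
    for stage_c3d, stage_nl, depth in zip(c3d_idx, nl_idx, depths):
        stage = []
        for i in range(depth):
            # Blocks listed in c3d_idx get the temporal block type (e.g. APP3DC or I3D).
            # All other blocks stay 2D (C2D).
            stage.append("C3D" if i in stage_c3d else "C2D")
            # A non-local block is appended right after block i.
            if i in stage_nl:
                stage.append("NL")
        layout.append(stage)
    return layout


# Index settings copied from AP3DNLResNet50 above.
ap3d_nl = layer_layout(c3d_idx=[[], [0, 2], [0, 2, 4], []],
                       nl_idx=[[], [1, 3], [1, 3, 5], []])
for name, stage in zip(["layer1", "layer2", "layer3", "layer4"], ap3d_nl):
    print(name, stage)
```

With the all-empty lists used by C2DResNet50, every block stays C2D. The `block` argument (APP3DC) is therefore never used by that model.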

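`ResNet503D.forward` reduces the final `(b, c, t, h, w)` feature map in two steps. It first applies max pooling over the full `h x w` map of every frame, then averages the result over the `t` frames, and only then applies `BatchNorm1d`. The sketch below reproduces just that reduction on nested Python lists instead of tensors, and omits the BatchNorm. The function name `spatial_max_temporal_mean` is hypothetical.

```python
def spatial_max_temporal_mean(x):
    """x is a nested list shaped [b][c][t][h][w]; returns a [b][c] nested list.

    Mirrors the pooling in ResNet503D.forward: F.max_pool2d over the whole
    h x w map of each frame, then x.mean(1) over the t frames.
    """
    out = []
    for sample in x:
        feat = []
        for channel in sample:
            # Spatial max pooling: one value per frame.
            per_frame = [max(max(row) for row in frame) for frame in channel]
            # Temporal average pooling over the frames.
            feat.append(sum(per_frame) / len(per_frame))
        out.append(feat)
    return out


# b=1, c=2, t=2, h=2, w=2
x = [[
    [[[1.0, 5.0], [2.0, 0.0]], [[3.0, 1.0], [0.0, 0.0]]],  # channel 0: frame maxima 5 and 3
    [[[0.0, 0.0], [0.0, 4.0]], [[2.0, 2.0], [2.0, 2.0]]],  # channel 1: frame maxima 4 and 2
]]
print(spatial_max_temporal_mean(x))  # [[4.0, 3.0]]
```

The output has one value per channel regardless of `t`. This is why clips of any length end up as a single 2048-dimensional vector before the BatchNorm.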
================================================
FILE: script.sh
================================================
# The code is built with DistributedDataParallel.
# To reproduce the results in the paper, train the model on 2 GPUs.
# You can also train the model on a single GPU by doubling config.DATA.TRAIN_BATCH in the config file.
# For LTCC dataset
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ltcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 #
# For PRCC dataset
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset prcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 #
# For VC-Clothes dataset. Change the root path in '--resume' to your own output path.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes --cfg configs/res50_cels_cal.yaml --gpu 0,1 #
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes_cc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume /data/guxinqian/logs/vcclothes/res50-cels-cal/best_model.pth.tar #
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes_sc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume /data/guxinqian/logs/vcclothes/res50-cels-cal/best_model.pth.tar #
# For DeepChange dataset. Using --amp (automatic mixed precision) can accelerate training.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset deepchange --cfg configs/res50_cels_cal_16x4.yaml --amp --gpu 0,1 #
# For LaST dataset. Using --amp (automatic mixed precision) can accelerate training.
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset last --cfg configs/res50_cels_cal_tri_16x4.yaml --amp --gpu 0,1 #
# For CCVID dataset
python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ccvid --cfg configs/c2dres50_ce_cal.yaml --gpu 0,1 #

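The comment at the top of script.sh implies that `config.DATA.TRAIN_BATCH` is a per-process batch size under DistributedDataParallel. On that reading, the global batch is `TRAIN_BATCH` multiplied by the number of GPUs. Under this assumption, the `*_16x4.yaml` configs (TRAIN_BATCH: 32, NUM_INSTANCES: 4, launched with `--nproc_per_node=2`) give a global batch of 64 images. That is 16 identities with 4 images each, which is consistent with the `16x4` suffix. The helper names below are hypothetical.

```python
def global_batch(train_batch_per_gpu, num_gpus):
    # Assumption: each DDP process draws its own TRAIN_BATCH samples.
    return train_batch_per_gpu * num_gpus


def identities_per_batch(batch_size, num_instances):
    # An identity-balanced sampler draws NUM_INSTANCES images per identity.
    if batch_size % num_instances:
        raise ValueError("batch size must be a multiple of num_instances")
    return batch_size // num_instances


two_gpus = global_batch(32, 2)     # res50_cels_cal_16x4.yaml on 2 GPUs
one_gpu = global_batch(2 * 32, 1)  # single GPU with TRAIN_BATCH doubled
print(two_gpus, one_gpu, identities_per_batch(two_gpus, 4))  # 64 64 16
```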
================================================
FILE: test.py
================================================
import time
import datetime
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch import distributed as dist
from tools.eval_metrics import evaluate, evaluate_with_clothes


VID_DATASET = ['ccvid']


def concat_all_gather(tensors, num_total_examples):
    '''
    Performs all_gather operation on the provided tensor list.
    '''
    outputs = []
    for tensor in tensors:
        tensor = tensor.cuda()
        tensors_gather = [tensor.clone() for _ in range(dist.get_world_size())]
        dist.all_gather(tensors_gather, tensor)
        output = torch.cat(tensors_gather, dim=0).cpu()
        # truncate the dummy elements added by DistributedInferenceSampler
        outputs.append(output[:num_total_examples])
    return outputs


@torch.no_grad()
def extract_img_feature(model, dataloader):
    features, pids, camids, clothes_ids = [], torch.tensor([]), torch.tensor([]), torch.tensor([])
    for batch_idx, (imgs, batch_pids, batch_camids, batch_clothes_ids) in enumerate(dataloader):
        flip_imgs = torch.flip(imgs, [3])
        imgs, flip_imgs = imgs.cuda(), flip_imgs.cuda()
        batch_features = model(imgs)
        batch_features_flip = model(flip_imgs)
        batch_features += batch_features_flip
        batch_features = F.normalize(batch_features, p=2, dim=1)

        features.append(batch_features.cpu())
        pids = torch.cat((pids, batch_pids.cpu()), dim=0)
        camids = torch.cat((camids, batch_camids.cpu()), dim=0)
        clothes_ids = torch.cat((clothes_ids, batch_clothes_ids.cpu()), dim=0)
    features = torch.cat(features, 0)

    return features, pids, camids, clothes_ids


@torch.no_grad()
def extract_vid_feature(model, dataloader, vid2clip_index, data_length):
    # `logger` is local to test(), so it must be fetched here as well to avoid a NameError
    logger = logging.getLogger('reid.test')
    # In build_dataloader, each original test video is split into a series of equal-length clips.
    # During testing, we first extract features for all clips.
    clip_features, clip_pids, clip_camids, clip_clothes_ids = [], torch.tensor([]), torch.tensor([]), torch.tensor([])
    for batch_idx, (vids, batch_pids, batch_camids, batch_clothes_ids) in enumerate(dataloader):
        if (batch_idx + 1) % 200 == 0:
            logger.info("{}/{}".format(batch_idx+1, len(dataloader)))
        vids = vids.cuda()
        batch_features = model(vids)
        clip_features.append(batch_features.cpu())
        clip_pids = torch.cat((clip_pids, batch_pids.cpu()), dim=0)
        clip_camids = torch.cat((clip_camids, batch_camids.cpu()), dim=0)
        clip_clothes_ids = torch.cat((clip_clothes_ids, batch_clothes_ids.cpu()), dim=0)
    clip_features = torch.cat(clip_features, 0)

    # Gather samples from different GPUs
    clip_features, clip_pids, clip_camids, clip_clothes_ids = \
        concat_all_gather([clip_features, clip_pids, clip_camids, clip_clothes_ids], data_length)

    # Use the averaged feature of all clips split from a video as the representation of this original full-length video
    features = torch.zeros(len(vid2clip_index), clip_features.size(1)).cuda()
    clip_features = clip_features.cuda()
    pids = torch.zeros(len(vid2clip_index))
    camids = torch.zeros(len(vid2clip_index))
    clothes_ids = torch.zeros(len(vid2clip_index))
    for i, idx in enumerate(vid2clip_index):
        features[i] = clip_features[idx[0] : idx[1], :].mean(0)
        features[i] = F.normalize(features[i], p=2, dim=0)
        pids[i] = clip_pids[idx[0]]
        camids[i] = clip_camids[idx[0]]
        clothes_ids[i] = clip_clothes_ids[idx[0]]
    features = features.cpu()

    return features, pids, camids, clothes_ids


def test(config, model, queryloader, galleryloader, dataset):
    logger = logging.getLogger('reid.test')
    since = time.time()
    model.eval()
    local_rank = dist.get_rank()
    # Extract features 
    if config.DATA.DATASET in VID_DATASET:
        qf, q_pids, q_camids, q_clothes_ids = extract_vid_feature(model, queryloader, 
                                                                  dataset.query_vid2clip_index,
                                                                  len(dataset.recombined_query))
        gf, g_pids, g_camids, g_clothes_ids = extract_vid_feature(model, galleryloader, 
                                                                  dataset.gallery_vid2clip_index,
                                                                  len(dataset.recombined_gallery))
    else:
        qf, q_pids, q_camids, q_clothes_ids = extract_img_feature(model, queryloader)
        gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader)
        # Gather samples from different GPUs
        torch.cuda.empty_cache()
        qf, q_pids, q_camids, q_clothes_ids = concat_all_gather([qf, q_pids, q_camids, q_clothes_ids], len(dataset.query))
        gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery))
    torch.cuda.empty_cache()
    time_elapsed = time.time() - since
    
    logger.info("Extracted features for query set, obtained {} matrix".format(qf.shape))    
    logger.info("Extracted features for gallery set, obtained {} matrix".format(gf.shape))
    logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    # Compute distance matrix between query and gallery
    since = time.time()
    m, n = qf.size(0), gf.size(0)
    distmat = torch.zeros((m,n))
    qf, gf = qf.cuda(), gf.cuda()
    # Distance = negative cosine similarity (query and gallery features are L2-normalized)
    for i in range(m):
        distmat[i] = (- torch.mm(qf[i:i+1], gf.t())).cpu()
    distmat = distmat.numpy()
    q_pids, q_camids, q_clothes_ids = q_pids.numpy(), q_camids.numpy(), q_clothes_ids.numpy()
    g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy()
    time_elapsed = time.time() - since
    logger.info('Distance computing in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    since = time.time()
    logger.info("Comput
SYMBOL INDEX (262 symbols across 37 files)

FILE: configs/default_img.py
  function update_config (line 128) | def update_config(config, args):
  function get_img_config (line 159) | def get_img_config(args):

FILE: configs/default_vid.py
  function update_config (line 134) | def update_config(config, args):
  function get_vid_config (line 165) | def get_vid_config(args):

FILE: data/__init__.py
  function get_names (line 30) | def get_names():
  function build_dataset (line 34) | def build_dataset(config):
  function build_img_transforms (line 49) | def build_img_transforms(config):
  function build_vid_transforms (line 67) | def build_vid_transforms(config):
  function build_dataloader (line 94) | def build_dataloader(config):

FILE: data/dataloader.py
  class BackgroundGenerator (line 29) | class BackgroundGenerator(threading.Thread):
    method __init__ (line 38) | def __init__(self, generator, local_rank, max_prefetch=10):
    method run (line 73) | def run(self):
    method next (line 81) | def next(self):
    method __next__ (line 88) | def __next__(self):
    method __iter__ (line 91) | def __iter__(self):
  class DataLoaderX (line 95) | class DataLoaderX(DataLoader):
    method __init__ (line 96) | def __init__(self, **kwargs):
    method __iter__ (line 102) | def __iter__(self):
    method _shutdown_background_thread (line 108) | def _shutdown_background_thread(self):
    method preload (line 124) | def preload(self):
    method __next__ (line 135) | def __next__(self):
    method shutdown (line 146) | def shutdown(self):

FILE: data/dataset_loader.py
  function read_image (line 8) | def read_image(img_path):
  class ImageDataset (line 24) | class ImageDataset(Dataset):
    method __init__ (line 26) | def __init__(self, dataset, transform=None):
    method __len__ (line 30) | def __len__(self):
    method __getitem__ (line 33) | def __getitem__(self, index):
  function pil_loader (line 41) | def pil_loader(path):
  function accimage_loader (line 48) | def accimage_loader(path):
  function get_default_image_loader (line 57) | def get_default_image_loader():
  function image_loader (line 65) | def image_loader(path):
  function video_loader (line 73) | def video_loader(img_paths, image_loader):
  function get_default_video_loader (line 84) | def get_default_video_loader():
  class VideoDataset (line 89) | class VideoDataset(Dataset):
    method __init__ (line 102) | def __init__(self,
    method __len__ (line 114) | def __len__(self):
    method __getitem__ (line 117) | def __getitem__(self, index):

FILE: data/datasets/ccvid.py
  class CCVID (line 14) | class CCVID(object):
    method __init__ (line 20) | def __init__(self, root='/data/datasets/', sampling_step=64, seq_len=1...
    method _check_before_run (line 79) | def _check_before_run(self):
    method _clothes2label_test (line 90) | def _clothes2label_test(self, query_path, gallery_path):
    method _process_data (line 114) | def _process_data(self, data_path, relabel=False, clothes2label=None):
    method _densesampling_for_trainingset (line 165) | def _densesampling_for_trainingset(self, dataset, sampling_step=64):
    method _recombination_for_testset (line 192) | def _recombination_for_testset(self, dataset, seq_len=16, stride=4):

FILE: data/datasets/deepchange.py
  class DeepChange (line 14) | class DeepChange(object):
    method __init__ (line 23) | def __init__(self, root='data', **kwargs):
    method _get_names (line 87) | def _get_names(self, fpath):
    method get_pid2label_and_clothes2label (line 95) | def get_pid2label_and_clothes2label(self, img_names1, img_names2=None):
    method _check_before_run (line 130) | def _check_before_run(self):
    method _process_dir (line 145) | def _process_dir(self, home_dir, img_names, clothes2label, pid2label=N...

FILE: data/datasets/last.py
  class LaST (line 14) | class LaST(object):
    method __init__ (line 25) | def __init__(self, root='data', **kwargs):
    method get_pid2label_and_clothes2label (line 71) | def get_pid2label_and_clothes2label(self, dir_path):
    method _check_before_run (line 102) | def _check_before_run(self):
    method _process_dir (line 117) | def _process_dir(self, dir_path, pid2label=None, clothes2label=None, r...

FILE: data/datasets/ltcc.py
  class LTCC (line 14) | class LTCC(object):
    method __init__ (line 23) | def __init__(self, root='data', **kwargs):
    method _check_before_run (line 61) | def _check_before_run(self):
    method _process_dir_train (line 72) | def _process_dir_train(self, dir_path):
    method _process_dir_test (line 108) | def _process_dir_test(self, query_path, gallery_path):

FILE: data/datasets/prcc.py
  class PRCC (line 14) | class PRCC(object):
    method __init__ (line 23) | def __init__(self, root='data', **kwargs):
    method _check_before_run (line 71) | def _check_before_run(self):
    method _process_dir_train (line 82) | def _process_dir_train(self, dir_path):
    method _process_dir_test (line 127) | def _process_dir_test(self, test_path):

FILE: data/datasets/vcclothes.py
  class VCClothes (line 14) | class VCClothes(object):
    method __init__ (line 23) | def __init__(self, root='data', mode='all', **kwargs):
    method _check_before_run (line 61) | def _check_before_run(self):
    method _process_dir_train (line 72) | def _process_dir_train(self):
    method _process_dir_test (line 109) | def _process_dir_test(self):
  function VCClothesSameClothes (line 178) | def VCClothesSameClothes(root='data', **kwargs):
  function VCClothesClothesChanging (line 182) | def VCClothesClothesChanging(root='data', **kwargs):

FILE: data/img_transforms.py
  class ResizeWithEqualScale (line 7) | class ResizeWithEqualScale(object):
    method __init__ (line 17) | def __init__(self, height, width, interpolation=Image.BILINEAR, fill_c...
    method __call__ (line 23) | def __call__(self, img):
  class RandomCroping (line 39) | class RandomCroping(object):
    method __init__ (line 46) | def __init__(self, p=0.5, interpolation=Image.BILINEAR):
    method __call__ (line 50) | def __call__(self, img):
  class RandomErasing (line 73) | class RandomErasing(object):
    method __init__ (line 88) | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, m...
    method __call__ (line 95) | def __call__(self, img):

FILE: data/samplers.py
  class RandomIdentitySampler (line 10) | class RandomIdentitySampler(Sampler):
    method __init__ (line 19) | def __init__(self, data_source, num_instances=4):
    method __iter__ (line 37) | def __iter__(self):
    method __len__ (line 60) | def __len__(self):
  class DistributedRandomIdentitySampler (line 64) | class DistributedRandomIdentitySampler(Sampler):
    method __init__ (line 81) | def __init__(self, data_source, num_instances=4,
    method __iter__ (line 124) | def __iter__(self):
    method __len__ (line 157) | def __len__(self):
    method set_epoch (line 160) | def set_epoch(self, epoch):
  class DistributedInferenceSampler (line 172) | class DistributedInferenceSampler(Sampler):
    method __init__ (line 184) | def __init__(self, dataset, rank=None, num_replicas=None):
    method __iter__ (line 200) | def __iter__(self):
    method __len__ (line 208) | def __len__(self):

FILE: data/spatial_transforms.py
  class Compose (line 15) | class Compose(object):
    method __init__ (line 28) | def __init__(self, transforms):
    method __call__ (line 31) | def __call__(self, img):
    method randomize_parameters (line 36) | def randomize_parameters(self):
  class ToTensor (line 41) | class ToTensor(object):
    method __init__ (line 47) | def __init__(self, norm_value=255):
    method __call__ (line 50) | def __call__(self, pic):
    method randomize_parameters (line 92) | def randomize_parameters(self):
  class Normalize (line 96) | class Normalize(object):
    method __init__ (line 108) | def __init__(self, mean, std):
    method __call__ (line 112) | def __call__(self, tensor):
    method randomize_parameters (line 124) | def randomize_parameters(self):
  class Scale (line 128) | class Scale(object):
    method __init__ (line 141) | def __init__(self, size, interpolation=Image.BILINEAR):
    method __call__ (line 146) | def __call__(self, img):
    method randomize_parameters (line 168) | def randomize_parameters(self):
  class RandomHorizontalFlip (line 172) | class RandomHorizontalFlip(object):
    method __call__ (line 175) | def __call__(self, img):
    method randomize_parameters (line 186) | def randomize_parameters(self):
  class RandomCrop (line 190) | class RandomCrop(object):
    method __init__ (line 199) | def __init__(self, size, p=0.5, interpolation=Image.BILINEAR):
    method __call__ (line 209) | def __call__(self, img):
    method randomize_parameters (line 228) | def randomize_parameters(self):
  class RandomErasing (line 234) | class RandomErasing(object):
    method __init__ (line 249) | def __init__(self, height=256, width=128, probability = 0.5, sl = 0.02...
    method __call__ (line 258) | def __call__(self, img):
    method randomize_parameters (line 270) | def randomize_parameters(self):

FILE: data/temporal_transforms.py
  class TemporalRandomCrop (line 5) | class TemporalRandomCrop(object):
    method __init__ (line 16) | def __init__(self, size=4, stride=8):
    method __call__ (line 20) | def __call__(self, frame_indices):
  class TemporalBeginCrop (line 49) | class TemporalBeginCrop(object):
    method __init__ (line 60) | def __init__(self, size=8, stride=4):
    method __call__ (line 64) | def __call__(self, frame_indices):
  class TemporalDivisionCrop (line 80) | class TemporalDivisionCrop(object):
    method __init__ (line 86) | def __init__(self, size=4):
    method __call__ (line 89) | def __call__(self, frame_indices):

FILE: losses/__init__.py
  function build_losses (line 11) | def build_losses(config, num_train_clothes):

FILE: losses/arcface_loss.py
  class ArcFaceLoss (line 7) | class ArcFaceLoss(nn.Module):
    method __init__ (line 17) | def __init__(self, scale=16, margin=0.1):
    method forward (line 22) | def forward(self, inputs, targets):

FILE: losses/circle_loss.py
  class CircleLoss (line 8) | class CircleLoss(nn.Module):
    method __init__ (line 18) | def __init__(self, scale=96, margin=0.3, **kwargs):
    method forward (line 23) | def forward(self, inputs, targets):
  class PairwiseCircleLoss (line 43) | class PairwiseCircleLoss(nn.Module):
    method __init__ (line 53) | def __init__(self, scale=48, margin=0.35, **kwargs):
    method forward (line 58) | def forward(self, inputs, targets):

FILE: losses/clothes_based_adversarial_loss.py
  class ClothesBasedAdversarialLoss (line 7) | class ClothesBasedAdversarialLoss(nn.Module):
    method __init__ (line 17) | def __init__(self, scale=16, epsilon=0.1):
    method forward (line 22) | def forward(self, inputs, targets, positive_mask):
  class ClothesBasedAdversarialLossWithMemoryBank (line 46) | class ClothesBasedAdversarialLossWithMemoryBank(nn.Module):
    method __init__ (line 59) | def __init__(self, num_clothes, feat_dim, momentum=0., scale=16, epsil...
    method forward (line 71) | def forward(self, inputs, targets, positive_mask):
    method _update_memory (line 109) | def _update_memory(self, features, labels):

FILE: losses/contrastive_loss.py
  class ContrastiveLoss (line 8) | class ContrastiveLoss(nn.Module):
    method __init__ (line 14) | def __init__(self, scale=16, **kwargs):
    method forward (line 18) | def forward(self, inputs, targets):

FILE: losses/cosface_loss.py
  class CosFaceLoss (line 8) | class CosFaceLoss(nn.Module):
    method __init__ (line 18) | def __init__(self, scale=16, margin=0.1, **kwargs):
    method forward (line 23) | def forward(self, inputs, targets):
  class PairwiseCosFaceLoss (line 37) | class PairwiseCosFaceLoss(nn.Module):
    method __init__ (line 47) | def __init__(self, scale=16, margin=0):
    method forward (line 52) | def forward(self, inputs, targets):

FILE: losses/cross_entropy_loss_with_label_smooth.py
  class CrossEntropyWithLabelSmooth (line 5) | class CrossEntropyWithLabelSmooth(nn.Module):
    method __init__ (line 16) | def __init__(self, epsilon=0.1):
    method forward (line 21) | def forward(self, inputs, targets):

FILE: losses/gather.py
  class GatherLayer (line 5) | class GatherLayer(torch.autograd.Function):
    method forward (line 9) | def forward(ctx, input):
    method backward (line 17) | def backward(ctx, *grads):

FILE: losses/triplet_loss.py
  class TripletLoss (line 8) | class TripletLoss(nn.Module):
    method __init__ (line 19) | def __init__(self, margin=0.3):
    method forward (line 24) | def forward(self, inputs, targets):

FILE: main.py
  function parse_option (line 30) | def parse_option():
  function main (line 53) | def main(config):

FILE: models/__init__.py
  function build_model (line 17) | def build_model(config, num_identities, num_clothes):

FILE: models/classifier.py
  class Classifier (line 11) | class Classifier(nn.Module):
    method __init__ (line 12) | def __init__(self, feature_dim, num_classes):
    method forward (line 18) | def forward(self, x):
  class NormalizedClassifier (line 24) | class NormalizedClassifier(nn.Module):
    method __init__ (line 25) | def __init__(self, feature_dim, num_classes):
    method forward (line 30) | def forward(self, x):

FILE: models/img_resnet.py
  class ResNet50 (line 7) | class ResNet50(nn.Module):
    method __init__ (line 8) | def __init__(self, config, **kwargs):
    method forward (line 32) | def forward(self, x):

FILE: models/utils/c3d_blocks.py
  class APM (line 6) | class APM(nn.Module):
    method __init__ (line 7) | def __init__(self, in_channels, out_channels, time_dim=3, temperature=...
    method forward (line 27) | def forward(self, x):
  class C2D (line 64) | class C2D(nn.Module):
    method __init__ (line 65) | def __init__(self, conv2d, **kwargs):
    method forward (line 84) | def forward(self, x):
  class I3D (line 90) | class I3D(nn.Module):
    method __init__ (line 91) | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
    method forward (line 111) | def forward(self, x):
  class API3D (line 117) | class API3D(nn.Module):
    method __init__ (line 118) | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, c...
    method forward (line 141) | def forward(self, x):
  class P3DA (line 148) | class P3DA(nn.Module):
    method __init__ (line 149) | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
    method forward (line 186) | def forward(self, x):
  class P3DB (line 193) | class P3DB(nn.Module):
    method __init__ (line 194) | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
    method forward (line 226) | def forward(self, x):
  class P3DC (line 237) | class P3DC(nn.Module):
    method __init__ (line 238) | def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):
    method forward (line 270) | def forward(self, x):
  class APP3DA (line 278) | class APP3DA(nn.Module):
    method __init__ (line 279) | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, c...
    method forward (line 319) | def forward(self, x):
  class APP3DB (line 326) | class APP3DB(nn.Module):
    method __init__ (line 327) | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, c...
    method forward (line 362) | def forward(self, x):
  class APP3DC (line 370) | class APP3DC(nn.Module):
    method __init__ (line 371) | def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, c...
    method forward (line 406) | def forward(self, x):

FILE: models/utils/inflate.py
  function inflate_conv (line 7) | def inflate_conv(conv2d,
  function inflate_linear (line 43) | def inflate_linear(linear2d, time_dim):
  function inflate_batch_norm (line 58) | def inflate_batch_norm(batch2d):
  function inflate_pool (line 69) | def inflate_pool(pool2d,
  class MaxPool2dFor3dInput (line 95) | class MaxPool2dFor3dInput(nn.Module):
    method __init__ (line 100) | def __init__(self, kernel_size, stride=None, padding=0, dilation=1):
    method forward (line 103) | def forward(self, x):

FILE: models/utils/nonlocal_blocks.py
  class NonLocalBlockND (line 8) | class NonLocalBlockND(nn.Module):
    method __init__ (line 9) | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_...
    method forward (line 82) | def forward(self, x):
  class NonLocalBlock1D (line 107) | class NonLocalBlock1D(NonLocalBlockND):
    method __init__ (line 108) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
  class NonLocalBlock2D (line 115) | class NonLocalBlock2D(NonLocalBlockND):
    method __init__ (line 116) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...
  class NonLocalBlock3D (line 123) | class NonLocalBlock3D(NonLocalBlockND):
    method __init__ (line 124) | def __init__(self, in_channels, inter_channels=None, sub_sample=True, ...

FILE: models/utils/pooling.py
  class GeMPooling (line 6) | class GeMPooling(nn.Module):
    method __init__ (line 7) | def __init__(self, p=3, eps=1e-6):
    method forward (line 12) | def forward(self, x):
  class MaxAvgPooling (line 16) | class MaxAvgPooling(nn.Module):
    method __init__ (line 17) | def __init__(self):
    method forward (line 22) | def forward(self, x):

FILE: models/vid_resnet.py
  class Bottleneck3D (line 15) | class Bottleneck3D(nn.Module):
    method __init__ (line 16) | def __init__(self, bottleneck2d, block, inflate_time=False, temperatur...
    method _inflate_downsample (line 34) | def _inflate_downsample(self, downsample2d, time_stride=1):
    method forward (line 41) | def forward(self, x):
  class ResNet503D (line 63) | class ResNet503D(nn.Module):
    method __init__ (line 64) | def __init__(self, config, block, c3d_idx, nl_idx, **kwargs):
    method _inflate_reslayer (line 97) | def _inflate_reslayer(self, reslayer2d, c3d_idx, nonlocal_idx=[], nonl...
    method forward (line 113) | def forward(self, x):
  function C2DResNet50 (line 137) | def C2DResNet50(config, **kwargs):
  function AP3DResNet50 (line 144) | def AP3DResNet50(config, **kwargs):
  function I3DResNet50 (line 151) | def I3DResNet50(config, **kwargs):
  function AP3DNLResNet50 (line 158) | def AP3DNLResNet50(config, **kwargs):
  function NLResNet50 (line 165) | def NLResNet50(config, **kwargs):

FILE: test.py
  function concat_all_gather (line 14) | def concat_all_gather(tensors, num_total_examples):
  function extract_img_feature (line 30) | def extract_img_feature(model, dataloader):
  function extract_vid_feature (line 50) | def extract_vid_feature(model, dataloader, vid2clip_index, data_length):
  function test (line 86) | def test(config, model, queryloader, galleryloader, dataset):
  function test_prcc (line 152) | def test_prcc(model, queryloader_same, queryloader_diff, galleryloader, ...

FILE: tools/eval_metrics.py
  function compute_ap_cmc (line 5) | def compute_ap_cmc(index, good_index, junk_index):
  function evaluate (line 30) | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids):
  function evaluate_with_clothes (line 75) | def evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q...

FILE: tools/utils.py
  function set_seed (line 13) | def set_seed(seed=None):
  function mkdir_if_missing (line 26) | def mkdir_if_missing(directory):
  function read_json (line 35) | def read_json(fpath):
  function write_json (line 41) | def write_json(obj, fpath):
  class AverageMeter (line 47) | class AverageMeter(object):
    method __init__ (line 52) | def __init__(self):
    method reset (line 55) | def reset(self):
    method update (line 61) | def update(self, val, n=1):
  function save_checkpoint (line 68) | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
  function get_logger (line 114) | def get_logger(fpath, local_rank=0, name=''):

FILE: train.py
  function train_cal (line 9) | def train_cal(config, epoch, model, classifier, clothes_classifier, crit...
  function train_cal_with_memory (line 95) | def train_cal_with_memory(config, epoch, model, classifier, criterion_cl...
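`concat_all_gather` in test.py relies on the inference sampler padding the dataset so that every rank receives the same number of samples. `dist.all_gather` needs equally sized tensors. After gathering, the result is truncated back to `num_total_examples`. The pure-Python simulation below assumes contiguous per-rank shards, padded by wrapping around to the first indices. That is an assumption about `DistributedInferenceSampler`, whose body is not shown in this excerpt. The simulation checks that concatenating the rank outputs in rank order and then truncating recovers the original order. `shard_indices` and `gather_and_truncate` are hypothetical names.

```python
import math


def shard_indices(num_examples, num_replicas, rank):
    """Contiguous, padded shard for one rank (assumed sampler behaviour)."""
    per_rank = math.ceil(num_examples / num_replicas)
    total = per_rank * num_replicas
    # Pad by wrapping around so every rank gets exactly per_rank items.
    padded = [i % num_examples for i in range(total)]
    return padded[rank * per_rank:(rank + 1) * per_rank]


def gather_and_truncate(per_rank_outputs, num_total_examples):
    """Mimic dist.all_gather + torch.cat(dim=0) + [:num_total_examples]."""
    gathered = [item for rank_output in per_rank_outputs for item in rank_output]
    # Everything past num_total_examples is padding added by the sampler.
    return gathered[:num_total_examples]


num_examples, world_size = 10, 4
shards = [shard_indices(num_examples, world_size, r) for r in range(world_size)]
print(shards)                                     # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1]]
print(gather_and_truncate(shards, num_examples))  # [0, 1, 2, ..., 9]
```

The truncation only preserves order if shards are contiguous. With strided sharding such as `torch.utils.data.DistributedSampler` (rank r takes every `num_replicas`-th index starting at r), the concatenated result would be interleaved. Truncating it would then misalign features with their labels.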
Condensed preview: 45 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1799,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 1332,
    "preview": "### A Simple Codebase for Clothes-Changing Person Re-identification.\n####  [Clothes-Changing Person Re-identification wi"
  },
  {
    "path": "configs/c2dres50_ce_cal.yaml",
    "chars": 86,
    "preview": "MODEL:\n  NAME: c2dres50\nLOSS:\n  CLA_LOSS: crossentropy\n  CAL: cal\nTAG: c2dres50-ce-cal"
  },
  {
    "path": "configs/default_img.py",
    "chars": 4914,
    "preview": "import os\nimport yaml\nfrom yacs.config import CfgNode as CN\n\n\n_C = CN()\n# ----------------------------------------------"
  },
  {
    "path": "configs/default_vid.py",
    "chars": 5269,
    "preview": "import os\nimport yaml\nfrom yacs.config import CfgNode as CN\n\n\n_C = CN()\n# ----------------------------------------------"
  },
  {
    "path": "configs/res50_cels_cal.yaml",
    "chars": 96,
    "preview": "MODEL:\n  NAME: resnet50\nLOSS:\n  CLA_LOSS: crossentropylabelsmooth\n  CAL: cal\nTAG: res50-cels-cal"
  },
  {
    "path": "configs/res50_cels_cal_16x4.yaml",
    "chars": 144,
    "preview": "MODEL:\n  NAME: resnet50\nDATA:\n  NUM_INSTANCES: 4\n  TRAIN_BATCH: 32\nLOSS:\n  CLA_LOSS: crossentropylabelsmooth\n  CAL: cal\n"
  },
  {
    "path": "configs/res50_cels_cal_tri_16x4.yaml",
    "chars": 207,
    "preview": "MODEL:\n  NAME: resnet50\nDATA:\n  NUM_INSTANCES: 4\n  TRAIN_BATCH: 32\nLOSS:\n  CLA_LOSS: crossentropylabelsmooth\n  PAIR_LOSS"
  },
  {
    "path": "data/__init__.py",
    "chars": 8635,
    "preview": "import data.img_transforms as T\nimport data.spatial_transforms as ST\nimport data.temporal_transforms as TT\nfrom torch.ut"
  },
  {
    "path": "data/dataloader.py",
    "chars": 5651,
    "preview": "# refer to: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/data/data_utils.py\n\nimport torch\nimport threading\n"
  },
  {
    "path": "data/dataset_loader.py",
    "chars": 4341,
    "preview": "import torch\nimport functools\nimport os.path as osp\nfrom PIL import Image\nfrom torch.utils.data import Dataset\n\n\ndef rea"
  },
  {
    "path": "data/datasets/ccvid.py",
    "chars": 11959,
    "preview": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path a"
  },
  {
    "path": "data/datasets/deepchange.py",
    "chars": 8333,
    "preview": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path a"
  },
  {
    "path": "data/datasets/last.py",
    "chars": 6511,
    "preview": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path a"
  },
  {
    "path": "data/datasets/ltcc.py",
    "chars": 6851,
    "preview": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path a"
  },
  {
    "path": "data/datasets/prcc.py",
    "chars": 8077,
    "preview": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path a"
  },
  {
    "path": "data/datasets/vcclothes.py",
    "chars": 7809,
    "preview": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path a"
  },
  {
    "path": "data/img_transforms.py",
    "chars": 4034,
    "preview": "from torchvision.transforms import *\nfrom PIL import Image\nimport random\nimport math\n\n\nclass ResizeWithEqualScale(object"
  },
  {
    "path": "data/samplers.py",
    "chars": 8152,
    "preview": "import copy\nimport math\nimport random\nimport numpy as np\nfrom torch import distributed as dist\nfrom collections import d"
  },
  {
    "path": "data/spatial_transforms.py",
    "chars": 9440,
    "preview": "import random\nimport math\nimport numbers\nimport collections\nimport numpy as np\nimport torch\nimport torchvision.transform"
  },
  {
    "path": "data/temporal_transforms.py",
    "chars": 3548,
    "preview": "import random\nimport numpy as np\n\n\nclass TemporalRandomCrop(object):\n    \"\"\"Temporally crop the given frame indices at a"
  },
  {
    "path": "losses/__init__.py",
    "chars": 2947,
    "preview": "from torch import nn\nfrom losses.cross_entropy_loss_with_label_smooth import CrossEntropyWithLabelSmooth\nfrom losses.tri"
  },
  {
    "path": "losses/arcface_loss.py",
    "chars": 1412,
    "preview": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass ArcFaceLoss(nn.Module):\n    \"\"\" Ar"
  },
  {
    "path": "losses/circle_loss.py",
    "chars": 3439,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import distributed as dist\nfrom losses.gath"
  },
  {
    "path": "losses/clothes_based_adversarial_loss.py",
    "chars": 5478,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom losses.gather import GatherLayer\n\n\nclass ClothesB"
  },
  {
    "path": "losses/contrastive_loss.py",
    "chars": 2013,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import distributed as dist\nfrom losses.gath"
  },
  {
    "path": "losses/cosface_loss.py",
    "chars": 2935,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import distributed as dist\nfrom losses.gath"
  },
  {
    "path": "losses/cross_entropy_loss_with_label_smooth.py",
    "chars": 1118,
    "preview": "import torch\nfrom torch import nn\n\n\nclass CrossEntropyWithLabelSmooth(nn.Module):\n    \"\"\" Cross entropy loss with label "
  },
  {
    "path": "losses/gather.py",
    "chars": 704,
    "preview": "import torch\nimport torch.distributed as dist\n\n\nclass GatherLayer(torch.autograd.Function):\n    \"\"\"Gather tensors from a"
  },
  {
    "path": "losses/triplet_loss.py",
    "chars": 1849,
    "preview": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom losses.gather import GatherLayer\n\n\ncl"
  },
  {
    "path": "main.py",
    "chars": 9661,
    "preview": "import os\nimport sys\nimport time\nimport datetime\nimport argparse\nimport logging\nimport os.path as osp\nimport numpy as np"
  },
  {
    "path": "models/__init__.py",
    "chars": 1452,
    "preview": "import logging\nfrom models.classifier import Classifier, NormalizedClassifier\nfrom models.img_resnet import ResNet50\nfro"
  },
  {
    "path": "models/classifier.py",
    "chars": 975,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom torch.nn import Pa"
  },
  {
    "path": "models/img_resnet.py",
    "chars": 1365,
    "preview": "import torchvision\nfrom torch import nn\nfrom torch.nn import init\nfrom models.utils import pooling\n        \n\nclass ResNe"
  },
  {
    "path": "models/utils/c3d_blocks.py",
    "chars": 17073,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass APM(nn.Module):\n    def __init__(self, in_cha"
  },
  {
    "path": "models/utils/inflate.py",
    "chars": 3925,
    "preview": "# inflate 2D modules to 3D modules\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\ndef inflate"
  },
  {
    "path": "models/utils/nonlocal_blocks.py",
    "chars": 4992,
    "preview": "import torch\nimport math\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom models.utils import inflate\n\n\ncl"
  },
  {
    "path": "models/utils/pooling.py",
    "chars": 694,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass GeMPooling(nn.Module):\n    def __init__(s"
  },
  {
    "path": "models/vid_resnet.py",
    "chars": 6188,
    "preview": "import torchvision\nimport torch.nn as nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom models.util"
  },
  {
    "path": "script.sh",
    "chars": 1862,
    "preview": "# The code is builded with DistributedDataParallel. \r\n# Reprodecing the results in the paper should train the model on 2"
  },
  {
    "path": "test.py",
    "chars": 10697,
    "preview": "import time\nimport datetime\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch im"
  },
  {
    "path": "tools/eval_metrics.py",
    "chars": 4850,
    "preview": "import logging\nimport numpy as np\n\n\ndef compute_ap_cmc(index, good_index, junk_index):\n    \"\"\" Compute AP and CMC for ea"
  },
  {
    "path": "tools/utils.py",
    "chars": 3386,
    "preview": "import os\nimport sys\nimport shutil\nimport errno\nimport json\nimport os.path as osp\nimport torch\nimport random\nimport logg"
  },
  {
    "path": "train.py",
    "chars": 6463,
    "preview": "import time\nimport datetime\nimport logging\nimport torch\nfrom apex import amp\nfrom tools.utils import AverageMeter\n\n\ndef "
  }
]
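Several config previews above set `DATA.NUM_INSTANCES: 4` and `DATA.TRAIN_BATCH: 32`, which suggests identity-balanced ("P x K") batch sampling: each batch holds P = TRAIN_BATCH / NUM_INSTANCES identities with K = NUM_INSTANCES samples each. The following is a hypothetical, single-process sketch of that idea; `pk_batches` is an illustrative name and is not the sampler implemented in `data/samplers.py`, whose contents are not shown here.

```python
import random
from collections import defaultdict


def pk_batches(pids, batch_size=32, num_instances=4, seed=0):
    """Yield batches of dataset indices containing P identities x K instances.

    Hypothetical helper for illustration; not the sampler in data/samplers.py.
    P = batch_size // num_instances, K = num_instances.
    """
    assert batch_size % num_instances == 0, "batch size must be a multiple of K"
    rng = random.Random(seed)
    num_ids_per_batch = batch_size // num_instances  # P

    # Group dataset indices by person ID.
    index_dic = defaultdict(list)
    for idx, pid in enumerate(pids):
        index_dic[pid].append(idx)

    # Build per-identity queues of K-sized chunks. Identities with fewer than
    # K samples are padded by sampling with replacement; any remainder beyond
    # a multiple of K is dropped.
    chunks = defaultdict(list)
    for pid, idxs in index_dic.items():
        idxs = list(idxs)
        if len(idxs) < num_instances:
            idxs += rng.choices(idxs, k=num_instances - len(idxs))
        rng.shuffle(idxs)
        usable = len(idxs) - len(idxs) % num_instances
        for start in range(0, usable, num_instances):
            chunks[pid].append(idxs[start:start + num_instances])

    # Each batch draws P distinct identities and takes one chunk from each.
    avail_pids = list(chunks)
    while len(avail_pids) >= num_ids_per_batch:
        batch = []
        for pid in rng.sample(avail_pids, num_ids_per_batch):
            batch.extend(chunks[pid].pop(0))
            if not chunks[pid]:
                avail_pids.remove(pid)
        yield batch


if __name__ == "__main__":
    # Toy dataset: 10 identities with 6 images each (60 indices).
    pids = [i // 6 for i in range(60)]
    for batch in pk_batches(pids, batch_size=32, num_instances=4):
        print(len(batch), sorted({pids[i] for i in batch}))
```

Guaranteeing K samples per identity inside every batch is what lets batch-hard pair losses (such as the triplet loss listed under `losses/`) find positives without a separate mining pass. How the repository distributes batches across GPUs is not visible in this preview.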

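The configs select `CLA_LOSS: crossentropylabelsmooth`, backed by `losses/cross_entropy_loss_with_label_smooth.py` (only its first lines appear above). As a sketch of the standard label-smoothing formulation, the one-hot target is replaced by (1 - epsilon) * one_hot + epsilon / K for K classes. The pure-Python function below computes this for one sample; the name `cross_entropy_label_smooth` and the default `epsilon=0.1` are assumptions, not values read from the repository.

```python
import math


def cross_entropy_label_smooth(logits, target, epsilon=0.1):
    """Label-smoothed cross entropy for a single sample.

    Sketch of the standard formulation; epsilon=0.1 is an assumed default.
    logits: list of K raw class scores; target: index of the true class.
    """
    k = len(logits)
    # Numerically stable log-softmax: subtract the max logit first.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]

    # Smoothed target distribution q_c = epsilon / K + (1 - epsilon) * [c == target].
    loss = 0.0
    for c, lp in enumerate(log_probs):
        q = epsilon / k + (1.0 - epsilon if c == target else 0.0)
        loss -= q * lp
    return loss


if __name__ == "__main__":
    print(cross_entropy_label_smooth([2.0, 0.5, -1.0], target=0))
```

With `epsilon=0` this reduces to ordinary cross entropy; with `epsilon > 0` the model is penalized for pushing the true-class probability all the way to 1, which tends to reduce overfitting to the training identities.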