[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "### A Simple Codebase for Clothes-Changing Person Re-identification.\n#### [Clothes-Changing Person Re-identification with RGB Modality Only (CVPR, 2022)](https://arxiv.org/abs/2204.06890)\n\n#### Requirements\n- Python 3.6\n- PyTorch 1.6.0\n- yacs\n- apex\n\n#### CCVID Dataset\n- [[BaiduYun]](https://pan.baidu.com/s/1W9yjqxS9qxfPUSu76JpE1g) password: q0q2\n- [[GoogleDrive]](https://drive.google.com/file/d/1vkZxm5v-aBXa_JEi23MMeW4DgisGtS4W/view?usp=sharing)\n\n#### Get Started\n- Replace `_C.DATA.ROOT` and `_C.OUTPUT` in `configs/default_img.py` and `configs/default_vid.py` with your own `data path` and `output path`, respectively.\n- Run `script.sh`\n\n\n#### Citation\n\nIf you use our code/dataset in your research or wish to refer to the baseline results, please use the following BibTeX entry.\n    \n    @inproceedings{gu2022CAL,\n        title={Clothes-Changing Person Re-identification with RGB Modality Only},\n        author={Gu, Xinqian and Chang, Hong and Ma, Bingpeng and Bai, Shutao and Shan, Shiguang and Chen, Xilin},\n        booktitle={CVPR},\n        year={2022},\n    }\n\n#### Related Repos\n\n- [Simple-ReID](https://github.com/guxinqian/Simple-ReID)\n- [fast-reid](https://github.com/JDAI-CV/fast-reid)\n- [deep-person-reid](https://github.com/KaiyangZhou/deep-person-reid)\n- [Pytorch ReID](https://github.com/layumi/Person_reID_baseline_pytorch)\n\n"
  },
  {
    "path": "configs/c2dres50_ce_cal.yaml",
    "content": "MODEL:\n  NAME: c2dres50\nLOSS:\n  CLA_LOSS: crossentropy\n  CAL: cal\nTAG: c2dres50-ce-cal"
  },
  {
    "path": "configs/default_img.py",
    "content": "import os\nimport yaml\nfrom yacs.config import CfgNode as CN\n\n\n_C = CN()\n# -----------------------------------------------------------------------------\n# Data settings\n# -----------------------------------------------------------------------------\n_C.DATA = CN()\n# Root path for dataset directory\n_C.DATA.ROOT = '/home/guxinqian/data'\n# Dataset for evaluation\n_C.DATA.DATASET = 'ltcc'\n# Workers for dataloader\n_C.DATA.NUM_WORKERS = 4\n# Height of input image\n_C.DATA.HEIGHT = 384\n# Width of input image\n_C.DATA.WIDTH = 192\n# Batch size for training\n_C.DATA.TRAIN_BATCH = 32\n# Batch size for testing\n_C.DATA.TEST_BATCH = 128\n# The number of instances per identity for training sampler\n_C.DATA.NUM_INSTANCES = 8\n# -----------------------------------------------------------------------------\n# Augmentation settings\n# -----------------------------------------------------------------------------\n_C.AUG = CN()\n# Random crop prob\n_C.AUG.RC_PROB = 0.5\n# Random erase prob\n_C.AUG.RE_PROB = 0.5\n# Random flip prob\n_C.AUG.RF_PROB = 0.5\n# -----------------------------------------------------------------------------\n# Model settings\n# -----------------------------------------------------------------------------\n_C.MODEL = CN()\n# Model name\n_C.MODEL.NAME = 'resnet50'\n# The stride for laery4 in resnet\n_C.MODEL.RES4_STRIDE = 1\n# feature dim\n_C.MODEL.FEATURE_DIM = 4096\n# Model path for resuming\n_C.MODEL.RESUME = ''\n# Global pooling after the backbone\n_C.MODEL.POOLING = CN()\n# Choose in ['avg', 'max', 'gem', 'maxavg']\n_C.MODEL.POOLING.NAME = 'maxavg'\n# Initialized power for GeM pooling\n_C.MODEL.POOLING.P = 3\n# -----------------------------------------------------------------------------\n# Losses for training \n# -----------------------------------------------------------------------------\n_C.LOSS = CN()\n# Classification loss\n_C.LOSS.CLA_LOSS = 'crossentropy'\n# Clothes classification loss\n_C.LOSS.CLOTHES_CLA_LOSS = 
'cosface'\n# Scale for classification loss\n_C.LOSS.CLA_S = 16.\n# Margin for classification loss\n_C.LOSS.CLA_M = 0.\n# Pairwise loss\n_C.LOSS.PAIR_LOSS = 'triplet'\n# The weight for pairwise loss\n_C.LOSS.PAIR_LOSS_WEIGHT = 0.0\n# Scale for pairwise loss\n_C.LOSS.PAIR_S = 16.\n# Margin for pairwise loss\n_C.LOSS.PAIR_M = 0.3\n# Clothes-based adversarial loss\n_C.LOSS.CAL = 'cal'\n# Epsilon for clothes-based adversarial loss\n_C.LOSS.EPSILON = 0.1\n# Momentum for clothes-based adversarial loss with memory bank\n_C.LOSS.MOMENTUM = 0.\n# -----------------------------------------------------------------------------\n# Training settings\n# -----------------------------------------------------------------------------\n_C.TRAIN = CN()\n_C.TRAIN.START_EPOCH = 0\n_C.TRAIN.MAX_EPOCH = 60\n# Start epoch for clothes classification\n_C.TRAIN.START_EPOCH_CC = 25\n# Start epoch for adversarial training\n_C.TRAIN.START_EPOCH_ADV = 25\n# Optimizer\n_C.TRAIN.OPTIMIZER = CN()\n_C.TRAIN.OPTIMIZER.NAME = 'adam'\n# Learning rate\n_C.TRAIN.OPTIMIZER.LR = 0.00035\n_C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4\n# LR scheduler\n_C.TRAIN.LR_SCHEDULER = CN()\n# Stepsize to decay learning rate\n_C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40]\n# LR decay rate, used in StepLRScheduler\n_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1\n# Using amp for training\n_C.TRAIN.AMP = False\n# -----------------------------------------------------------------------------\n# Testing settings\n# -----------------------------------------------------------------------------\n_C.TEST = CN()\n# Perform evaluation after every N epochs (set to -1 to test after training)\n_C.TEST.EVAL_STEP = 5\n# Start to evaluate after specific epoch\n_C.TEST.START_EVAL = 0\n# -----------------------------------------------------------------------------\n# Misc\n# -----------------------------------------------------------------------------\n# Fixed random seed\n_C.SEED = 1\n# Perform evaluation only\n_C.EVAL_MODE = False\n# GPU device ids for 
CUDA_VISIBLE_DEVICES\n_C.GPU = '0'\n# Path to output folder, overwritten by command line argument\n_C.OUTPUT = '/data/guxinqian/logs/'\n# Tag of experiment, overwritten by command line argument\n_C.TAG = 'res50-ce-cal'\n\n\ndef update_config(config, args):\n    config.defrost()\n    config.merge_from_file(args.cfg)\n\n    # merge from specific arguments\n    if args.root:\n        config.DATA.ROOT = args.root\n    if args.output:\n        config.OUTPUT = args.output\n\n    if args.resume:\n        config.MODEL.RESUME = args.resume\n    if args.eval:\n        config.EVAL_MODE = True\n    \n    if args.tag:\n        config.TAG = args.tag\n\n    if args.dataset:\n        config.DATA.DATASET = args.dataset\n    if args.gpu:\n        config.GPU = args.gpu\n    if args.amp:\n        config.TRAIN.AMP = True\n\n    # output folder\n    config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG)\n\n    config.freeze()\n\n\ndef get_img_config(args):\n    \"\"\"Get a yacs CfgNode object with default values.\"\"\"\n    config = _C.clone()\n    update_config(config, args)\n\n    return config\n"
  },
  {
    "path": "configs/default_vid.py",
    "content": "import os\nimport yaml\nfrom yacs.config import CfgNode as CN\n\n\n_C = CN()\n# -----------------------------------------------------------------------------\n# Data settings\n# -----------------------------------------------------------------------------\n_C.DATA = CN()\n# Root path for dataset directory\n_C.DATA.ROOT = '/home/guxinqian/data'\n# Dataset for evaluation\n_C.DATA.DATASET = 'ccvid'\n# Whether split each full-length video in the training set into some clips\n_C.DATA.DENSE_SAMPLING = True\n# Sampling step of dense sampling for training set\n_C.DATA.SAMPLING_STEP = 64\n# Workers for dataloader\n_C.DATA.NUM_WORKERS = 4\n# Height of input image\n_C.DATA.HEIGHT = 256\n# Width of input image\n_C.DATA.WIDTH = 128\n# Batch size for training\n_C.DATA.TRAIN_BATCH = 16\n# Batch size for testing\n_C.DATA.TEST_BATCH = 128\n# The number of instances per identity for training sampler\n_C.DATA.NUM_INSTANCES = 4\n# -----------------------------------------------------------------------------\n# Augmentation settings\n# -----------------------------------------------------------------------------\n_C.AUG = CN()\n# Random erase prob\n_C.AUG.RE_PROB = 0.0\n# Temporal sampling mode for training, 'tsn' or 'stride'\n_C.AUG.TEMPORAL_SAMPLING_MODE = 'stride'\n# Sequence length of each input video clip\n_C.AUG.SEQ_LEN = 8\n# Sampling stride of each input video clip\n_C.AUG.SAMPLING_STRIDE = 4\n# -----------------------------------------------------------------------------\n# Model settings\n# -----------------------------------------------------------------------------\n_C.MODEL = CN()\n# Model name. 
All supported model can be seen in models/__init__.py\n_C.MODEL.NAME = 'c2dres50'\n# The stride for laery4 in resnet\n_C.MODEL.RES4_STRIDE = 1\n# feature dim\n_C.MODEL.FEATURE_DIM = 2048\n# Model path for resuming\n_C.MODEL.RESUME = ''\n# Params for AP3D\n_C.MODEL.AP3D = CN()\n# Temperature for APM\n_C.MODEL.AP3D.TEMPERATURE = 4\n# Contrastive attention\n_C.MODEL.AP3D.CONTRACTIVE_ATT = True\n# -----------------------------------------------------------------------------\n# Losses for training \n# -----------------------------------------------------------------------------\n_C.LOSS = CN()\n# Classification loss\n_C.LOSS.CLA_LOSS = 'crossentropy'\n# Clothes classification loss\n_C.LOSS.CLOTHES_CLA_LOSS = 'cosface'\n# Scale for classification loss\n_C.LOSS.CLA_S = 16.\n# Margin for classification loss\n_C.LOSS.CLA_M = 0.\n# Pairwise loss\n_C.LOSS.PAIR_LOSS = 'triplet'\n# The weight for pairwise loss\n_C.LOSS.PAIR_LOSS_WEIGHT = 0.0\n# Scale for pairwise loss\n_C.LOSS.PAIR_S = 16.\n# Margin for pairwise loss\n_C.LOSS.PAIR_M = 0.3\n# Clothes-based adversarial loss\n_C.LOSS.CAL = 'cal'\n# Epsilon for clothes-based adversarial loss\n_C.LOSS.EPSILON = 0.1\n# Momentum for clothes-based adversarial loss with memory bank\n_C.LOSS.MOMENTUM = 0.\n# -----------------------------------------------------------------------------\n# Training settings\n# -----------------------------------------------------------------------------\n_C.TRAIN = CN()\n_C.TRAIN.START_EPOCH = 0\n_C.TRAIN.MAX_EPOCH = 150\n# Start epoch for clothes classification\n_C.TRAIN.START_EPOCH_CC = 50\n# Start epoch for adversarial training\n_C.TRAIN.START_EPOCH_ADV = 50\n# Optimizer\n_C.TRAIN.OPTIMIZER = CN()\n_C.TRAIN.OPTIMIZER.NAME = 'adam'\n# Learning rate\n_C.TRAIN.OPTIMIZER.LR = 0.00035\n_C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4\n# LR scheduler\n_C.TRAIN.LR_SCHEDULER = CN()\n# Stepsize to decay learning rate\n_C.TRAIN.LR_SCHEDULER.STEPSIZE = [40, 80, 120]\n# LR decay rate, used in 
StepLRScheduler\n_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1\n# Using amp for training\n_C.TRAIN.AMP = False\n# -----------------------------------------------------------------------------\n# Testing settings\n# -----------------------------------------------------------------------------\n_C.TEST = CN()\n# Perform evaluation after every N epochs (set to -1 to test after training)\n_C.TEST.EVAL_STEP = 10\n# Start to evaluate after specific epoch\n_C.TEST.START_EVAL = 0\n# -----------------------------------------------------------------------------\n# Misc\n# -----------------------------------------------------------------------------\n# Fixed random seed\n_C.SEED = 1\n# Perform evaluation only\n_C.EVAL_MODE = False\n# GPU device ids for CUDA_VISIBLE_DEVICES\n_C.GPU = '0, 1'\n# Path to output folder, overwritten by command line argument\n_C.OUTPUT = '/data/guxinqian/logs/'\n# Tag of experiment, overwritten by command line argument\n_C.TAG = 'res50-ce-cal'\n\n\ndef update_config(config, args):\n    config.defrost()\n    config.merge_from_file(args.cfg)\n\n    # merge from specific arguments\n    if args.root:\n        config.DATA.ROOT = args.root\n    if args.output:\n        config.OUTPUT = args.output\n\n    if args.resume:\n        config.MODEL.RESUME = args.resume\n    if args.eval:\n        config.EVAL_MODE = True\n    \n    if args.tag:\n        config.TAG = args.tag\n\n    if args.dataset:\n        config.DATA.DATASET = args.dataset\n    if args.gpu:\n        config.GPU = args.gpu\n    if args.amp:\n        config.TRAIN.AMP = True\n\n    # output folder\n    config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG)\n\n    config.freeze()\n\n\ndef get_vid_config(args):\n    \"\"\"Get a yacs CfgNode object with default values.\"\"\"\n    config = _C.clone()\n    update_config(config, args)\n\n    return config\n"
  },
  {
    "path": "configs/res50_cels_cal.yaml",
    "content": "MODEL:\n  NAME: resnet50\nLOSS:\n  CLA_LOSS: crossentropylabelsmooth\n  CAL: cal\nTAG: res50-cels-cal"
  },
  {
    "path": "configs/res50_cels_cal_16x4.yaml",
    "content": "MODEL:\n  NAME: resnet50\nDATA:\n  NUM_INSTANCES: 4\n  TRAIN_BATCH: 32\nLOSS:\n  CLA_LOSS: crossentropylabelsmooth\n  CAL: cal\nTAG: res50-cels-cal-16x4"
  },
  {
    "path": "configs/res50_cels_cal_tri_16x4.yaml",
    "content": "MODEL:\n  NAME: resnet50\nDATA:\n  NUM_INSTANCES: 4\n  TRAIN_BATCH: 32\nLOSS:\n  CLA_LOSS: crossentropylabelsmooth\n  PAIR_LOSS: triplet\n  CAL: cal\n  PAIR_M: 0.3\n  PAIR_LOSS_WEIGHT: 1.0\nTAG: res50-cels-cal-tri-16x4"
  },
  {
    "path": "data/__init__.py",
    "content": "import data.img_transforms as T\nimport data.spatial_transforms as ST\nimport data.temporal_transforms as TT\nfrom torch.utils.data import DataLoader\nfrom data.dataloader import DataLoaderX\nfrom data.dataset_loader import ImageDataset, VideoDataset\nfrom data.samplers import DistributedRandomIdentitySampler, DistributedInferenceSampler\nfrom data.datasets.ltcc import LTCC\nfrom data.datasets.prcc import PRCC\nfrom data.datasets.last import LaST\nfrom data.datasets.ccvid import CCVID\nfrom data.datasets.deepchange import DeepChange\nfrom data.datasets.vcclothes import VCClothes, VCClothesSameClothes, VCClothesClothesChanging\n\n\n__factory = {\n    'ltcc': LTCC,\n    'prcc': PRCC,\n    'vcclothes': VCClothes,\n    'vcclothes_sc': VCClothesSameClothes,\n    'vcclothes_cc': VCClothesClothesChanging,\n    'last': LaST,\n    'ccvid': CCVID,\n    'deepchange': DeepChange,\n}\n\nVID_DATASET = ['ccvid']\n\n\ndef get_names():\n    return list(__factory.keys())\n\n\ndef build_dataset(config):\n    if config.DATA.DATASET not in __factory.keys():\n        raise KeyError(\"Invalid dataset, got '{}', but expected to be one of {}\".format(name, __factory.keys()))\n\n    if config.DATA.DATASET in VID_DATASET:\n        dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT, \n                                                 sampling_step=config.DATA.SAMPLING_STEP,\n                                                 seq_len=config.AUG.SEQ_LEN, \n                                                 stride=config.AUG.SAMPLING_STRIDE)\n    else:\n        dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT)\n\n    return dataset\n\n\ndef build_img_transforms(config):\n    transform_train = T.Compose([\n        T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),\n        T.RandomCroping(p=config.AUG.RC_PROB),\n        T.RandomHorizontalFlip(p=config.AUG.RF_PROB),\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 
0.225]),\n        T.RandomErasing(probability=config.AUG.RE_PROB)\n    ])\n    transform_test = T.Compose([\n        T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)),\n        T.ToTensor(),\n        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n    ])\n\n    return transform_train, transform_test\n\n\ndef build_vid_transforms(config):\n    spatial_transform_train = ST.Compose([\n        ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3),\n        ST.RandomHorizontalFlip(),\n        ST.ToTensor(),\n        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n        ST.RandomErasing(height=config.DATA.HEIGHT, width=config.DATA.WIDTH, probability=config.AUG.RE_PROB)\n    ])\n    spatial_transform_test = ST.Compose([\n        ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3),\n        ST.ToTensor(),\n        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n    ])\n\n    if config.AUG.TEMPORAL_SAMPLING_MODE == 'tsn':\n        temporal_transform_train = TT.TemporalDivisionCrop(size=config.AUG.SEQ_LEN)\n    elif config.AUG.TEMPORAL_SAMPLING_MODE == 'stride':\n        temporal_transform_train = TT.TemporalRandomCrop(size=config.AUG.SEQ_LEN, \n                                                         stride=config.AUG.SAMPLING_STRIDE)\n    else:\n        raise KeyError(\"Invalid temporal sempling mode '{}'\".format(config.AUG.TEMPORAL_SAMPLING_MODE))\n\n    temporal_transform_test = None\n\n    return spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test\n\n\ndef build_dataloader(config):\n    dataset = build_dataset(config)\n    # video dataset\n    if config.DATA.DATASET in VID_DATASET:\n        spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test = build_vid_transforms(config)\n\n        if config.DATA.DENSE_SAMPLING:\n            train_sampler = DistributedRandomIdentitySampler(dataset.train_dense, \n      
                                                       num_instances=config.DATA.NUM_INSTANCES, \n                                                             seed=config.SEED)\n            # split each original training video into a series of short videos and sample one clip for each short video during training\n            trainloader = DataLoaderX(\n                dataset=VideoDataset(dataset.train_dense, spatial_transform_train, temporal_transform_train),\n                sampler=train_sampler,\n                batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,\n                pin_memory=True, drop_last=True)\n        else:\n            train_sampler = DistributedRandomIdentitySampler(dataset.train, \n                                                             num_instances=config.DATA.NUM_INSTANCES, \n                                                             seed=config.SEED)\n            # sample one clip for each original training video during training\n            trainloader = DataLoaderX(\n                dataset=VideoDataset(dataset.train, spatial_transform_train, temporal_transform_train),\n                sampler=train_sampler,\n                batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,\n                pin_memory=True, drop_last=True)\n        \n        # split each original test video into a series of clips and use the averaged feature of all clips as its representation\n        queryloader = DataLoaderX(\n            dataset=VideoDataset(dataset.recombined_query, spatial_transform_test, temporal_transform_test),\n            sampler=DistributedInferenceSampler(dataset.recombined_query),\n            batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,\n            pin_memory=True, drop_last=False, shuffle=False)\n        galleryloader = DataLoaderX(\n            dataset=VideoDataset(dataset.recombined_gallery, spatial_transform_test, temporal_transform_test),\n            
sampler=DistributedInferenceSampler(dataset.recombined_gallery),\n            batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,\n            pin_memory=True, drop_last=False, shuffle=False)\n\n        return trainloader, queryloader, galleryloader, dataset, train_sampler\n    # image dataset\n    else:\n        transform_train, transform_test = build_img_transforms(config)\n        train_sampler = DistributedRandomIdentitySampler(dataset.train, \n                                                         num_instances=config.DATA.NUM_INSTANCES, \n                                                         seed=config.SEED)\n        trainloader = DataLoaderX(dataset=ImageDataset(dataset.train, transform=transform_train),\n                                 sampler=train_sampler,\n                                 batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS,\n                                 pin_memory=True, drop_last=True)\n\n        galleryloader = DataLoaderX(dataset=ImageDataset(dataset.gallery, transform=transform_test),\n                                   sampler=DistributedInferenceSampler(dataset.gallery),\n                                   batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,\n                                   pin_memory=True, drop_last=False, shuffle=False)\n\n        if config.DATA.DATASET == 'prcc':\n            queryloader_same = DataLoaderX(dataset=ImageDataset(dataset.query_same, transform=transform_test),\n                                     sampler=DistributedInferenceSampler(dataset.query_same),\n                                     batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,\n                                     pin_memory=True, drop_last=False, shuffle=False)\n            queryloader_diff = DataLoaderX(dataset=ImageDataset(dataset.query_diff, transform=transform_test),\n                                     
sampler=DistributedInferenceSampler(dataset.query_diff),\n                                     batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,\n                                     pin_memory=True, drop_last=False, shuffle=False)\n\n            return trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler\n        else:\n            queryloader = DataLoaderX(dataset=ImageDataset(dataset.query, transform=transform_test),\n                                     sampler=DistributedInferenceSampler(dataset.query),\n                                     batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS,\n                                     pin_memory=True, drop_last=False, shuffle=False)\n\n            return trainloader, queryloader, galleryloader, dataset, train_sampler\n\n    \n\n    \n"
  },
  {
    "path": "data/dataloader.py",
    "content": "# refer to: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/data/data_utils.py\n\nimport torch\nimport threading\nimport queue\nfrom torch.utils.data import DataLoader\nfrom torch import distributed as dist\n\n\n\"\"\"\n#based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch\nThis is a single-function package that transforms arbitrary generator into a background-thead generator that \nprefetches several batches of data in a parallel background thead.\n\nThis is useful if you have a computationally heavy process (CPU or GPU) that \niteratively processes minibatches from the generator while the generator \nconsumes some other resource (disk IO / loading from database / more CPU if you have unused cores). \n\nBy default these two processes will constantly wait for one another to finish. If you make generator work in \nprefetch mode (see examples below), they will work in parallel, potentially saving you your GPU time.\nWe personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc.\n\nQuick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb\nThis package contains this object\n - BackgroundGenerator(any_other_generator[,max_prefetch = something])\n\"\"\"\n\n\nclass BackgroundGenerator(threading.Thread):\n    \"\"\"\n    the usage is below\n    >> for batch in BackgroundGenerator(my_minibatch_iterator):\n    >>    doit()\n    More details are written in the BackgroundGenerator doc\n    >> help(BackgroundGenerator)\n    \"\"\"\n\n    def __init__(self, generator, local_rank, max_prefetch=10):\n        \"\"\"\n        This function transforms generator into a background-thead generator.\n        :param generator: generator or genexp or any\n        It can be used with any minibatch generator.\n\n        It is quite lightweight, but not entirely weightless.\n        Using global variables inside generator is not 
recommended (may raise GIL and zero-out the\n        benefit of having a background thread.)\n        The ideal use case is when everything it requires is store inside it and everything it\n        outputs is passed through queue.\n\n        There's no restriction on doing weird stuff, reading/writing files, retrieving\n        URLs [or whatever] wlilst iterating.\n\n        :param max_prefetch: defines, how many iterations (at most) can background generator keep\n        stored at any moment of time.\n        Whenever there's already max_prefetch batches stored in queue, the background process will halt until\n        one of these batches is dequeued.\n\n        !Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator!\n\n        Setting max_prefetch to -1 lets it store as many batches as it can, which will work\n        slightly (if any) faster, but will require storing\n        all batches in memory. If you use infinite generator with max_prefetch=-1, it will exceed the RAM size\n        unless dequeued quickly enough.\n        \"\"\"\n        super().__init__()\n        self.queue = queue.Queue(max_prefetch)\n        self.generator = generator\n        self.local_rank = local_rank\n        self.daemon = True\n        self.exit_event = threading.Event()\n        self.start()\n\n    def run(self):\n        torch.cuda.set_device(self.local_rank)\n        for item in self.generator:\n            if self.exit_event.is_set():\n                break\n            self.queue.put(item)\n        self.queue.put(None)\n\n    def next(self):\n        next_item = self.queue.get()\n        if next_item is None:\n            raise StopIteration\n        return next_item\n\n    # Python 3 compatibility\n    def __next__(self):\n        return self.next()\n\n    def __iter__(self):\n        return self\n\n\nclass DataLoaderX(DataLoader):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        local_rank = dist.get_rank()\n  
      self.stream = torch.cuda.Stream(local_rank)  # create a new cuda stream in each process\n        self.local_rank = local_rank\n\n    def __iter__(self):\n        self.iter = super().__iter__()\n        self.iter = BackgroundGenerator(self.iter, self.local_rank)\n        self.preload()\n        return self\n\n    def _shutdown_background_thread(self):\n        if not self.iter.is_alive():\n            # avoid re-entrance or ill-conditioned thread state\n            return\n\n        # Set the exit event so that the background thread stops\n        self.iter.exit_event.set()\n\n        # Exhaust all remaining elements, so that the queue becomes empty,\n        # and the thread should quit\n        for _ in self.iter:\n            pass\n\n        # Wait for the background thread to quit\n        self.iter.join()\n\n    def preload(self):\n        self.batch = next(self.iter, None)\n        if self.batch is None:\n            return None\n        with torch.cuda.stream(self.stream):\n            # if isinstance(self.batch[0], torch.Tensor):\n            #     self.batch[0] = self.batch[0].to(device=self.local_rank, non_blocking=True)\n            for k, v in enumerate(self.batch):\n                if isinstance(self.batch[k], torch.Tensor):\n                    self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)\n\n    def __next__(self):\n        torch.cuda.current_stream().wait_stream(\n            self.stream\n        )  # wait until the batch has been copied to the GPU\n        batch = self.batch\n        if batch is None:\n            raise StopIteration\n        self.preload()\n        return batch\n\n    # Signal for shutting down background thread\n    def shutdown(self):\n        # If the dataloader is to be freed, shut down its BackgroundGenerator\n        self._shutdown_background_thread()\n"
  },
  {
    "path": "data/dataset_loader.py",
    "content": "import torch\nimport functools\nimport os.path as osp\nfrom PIL import Image\nfrom torch.utils.data import Dataset\n\n\ndef read_image(img_path):\n    \"\"\"Keep reading image until succeed.\n    This can avoid IOError incurred by heavy IO process.\"\"\"\n    got_img = False\n    if not osp.exists(img_path):\n        raise IOError(\"{} does not exist\".format(img_path))\n    while not got_img:\n        try:\n            img = Image.open(img_path).convert('RGB')\n            got_img = True\n        except IOError:\n            print(\"IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.\".format(img_path))\n            pass\n    return img\n\n\nclass ImageDataset(Dataset):\n    \"\"\"Image Person ReID Dataset\"\"\"\n    def __init__(self, dataset, transform=None):\n        self.dataset = dataset\n        self.transform = transform\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, index):\n        img_path, pid, camid, clothes_id = self.dataset[index]\n        img = read_image(img_path)\n        if self.transform is not None:\n            img = self.transform(img)\n        return img, pid, camid, clothes_id\n\n\ndef pil_loader(path):\n    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)\n    with open(path, 'rb') as f:\n        with Image.open(f) as img:\n            return img.convert('RGB')\n\n\ndef accimage_loader(path):\n    try:\n        import accimage\n        return accimage.Image(path)\n    except IOError:\n        # Potentially a decoding problem, fall back to PIL.Image\n        return pil_loader(path)\n\n\ndef get_default_image_loader():\n    from torchvision import get_image_backend\n    if get_image_backend() == 'accimage':\n        return accimage_loader\n    else:\n        return pil_loader\n\n\ndef image_loader(path):\n    from torchvision import get_image_backend\n    if get_image_backend() == 'accimage':\n        return 
accimage_loader(path)\n    else:\n        return pil_loader(path)\n\n\ndef video_loader(img_paths, image_loader):\n    video = []\n    for image_path in img_paths:\n        if osp.exists(image_path):\n            video.append(image_loader(image_path))\n        else:\n            return video\n\n    return video\n\n\ndef get_default_video_loader():\n    image_loader = get_default_image_loader()\n    return functools.partial(video_loader, image_loader=image_loader)\n\n\nclass VideoDataset(Dataset):\n    \"\"\"Video Person ReID Dataset.\n    Note:\n        Batch data has shape N x C x T x H x W\n    Args:\n        dataset (list): List with items (img_paths, pid, camid)\n        temporal_transform (callable, optional): A function/transform that  takes in a list of frame indices\n            and returns a transformed version\n        target_transform (callable, optional): A function/transform that takes in the\n            target and transforms it.\n        loader (callable, optional): A function to load an video given its path and frame indices.\n    \"\"\"\n\n    def __init__(self, \n                 dataset, \n                 spatial_transform=None,\n                 temporal_transform=None,\n                 get_loader=get_default_video_loader,\n                 cloth_changing=True):\n        self.dataset = dataset\n        self.spatial_transform = spatial_transform\n        self.temporal_transform = temporal_transform\n        self.loader = get_loader()\n        self.cloth_changing = cloth_changing\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, index):\n        \"\"\"\n        Args:\n            index (int): Index\n\n        Returns:\n            tuple: (clip, pid, camid) where pid is identity of the clip.\n        \"\"\"\n        if self.cloth_changing:\n            img_paths, pid, camid, clothes_id = self.dataset[index]\n        else:\n            img_paths, pid, camid = self.dataset[index]\n\n        if 
self.temporal_transform is not None:\n            img_paths = self.temporal_transform(img_paths)\n\n        clip = self.loader(img_paths)\n\n        if self.spatial_transform is not None:\n            self.spatial_transform.randomize_parameters()\n            clip = [self.spatial_transform(img) for img in clip]\n\n        # permute T x C x H x W to C x T x H x W\n        clip = torch.stack(clip, 0).permute(1, 0, 2, 3)\n\n        if self.cloth_changing:\n            return clip, pid, camid, clothes_id\n        else:\n            return clip, pid, camid"
  },
  {
    "path": "data/datasets/ccvid.py",
    "content": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nfrom tools.utils import mkdir_if_missing, write_json, read_json\n\n\nclass CCVID(object):\n    \"\"\" CCVID\n\n    Reference:\n        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022.\n    \"\"\"\n    def __init__(self, root='/data/datasets/', sampling_step=64, seq_len=16, stride=4, **kwargs):\n        self.root = osp.join(root, 'CCVID')\n        self.train_path = osp.join(self.root, 'train.txt')\n        self.query_path = osp.join(self.root, 'query.txt')\n        self.gallery_path = osp.join(self.root, 'gallery.txt')\n        self._check_before_run()\n \n        train, num_train_tracklets, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes, _ = \\\n            self._process_data(self.train_path, relabel=True)\n        clothes2label = self._clothes2label_test(self.query_path, self.gallery_path)\n        query, num_query_tracklets, num_query_pids, num_query_imgs, num_query_clothes, _, _ = \\\n            self._process_data(self.query_path, relabel=False, clothes2label=clothes2label)\n        gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs, num_gallery_clothes, _, _ = \\\n            self._process_data(self.gallery_path, relabel=False, clothes2label=clothes2label)\n\n        # slice each full-length video in the trainingset into more video clip\n        train_dense = self._densesampling_for_trainingset(train, sampling_step)\n        # In the test stage, each video sample is divided into a series of equilong video clips with a pre-defined stride.\n        recombined_query, query_vid2clip_index = self._recombination_for_testset(query, seq_len=seq_len, stride=stride)\n        recombined_gallery, gallery_vid2clip_index = self._recombination_for_testset(gallery, seq_len=seq_len, stride=stride)\n       \n        
num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs \n        min_num = np.min(num_imgs_per_tracklet)\n        max_num = np.max(num_imgs_per_tracklet)\n        avg_num = np.mean(num_imgs_per_tracklet)\n\n        num_total_pids = num_train_pids + num_gallery_pids\n        num_total_clothes = num_train_clothes + len(clothes2label)\n        num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets \n\n        logger = logging.getLogger('reid.dataset')\n        logger.info(\"=> CCVID loaded\")\n        logger.info(\"Dataset statistics:\")\n        logger.info(\"  ---------------------------------------------\")\n        logger.info(\"  subset       | # ids | # tracklets | # clothes\")\n        logger.info(\"  ---------------------------------------------\")\n        logger.info(\"  train        | {:5d} | {:11d} | {:9d}\".format(num_train_pids, num_train_tracklets, num_train_clothes))\n        logger.info(\"  train_dense  | {:5d} | {:11d} | {:9d}\".format(num_train_pids, len(train_dense), num_train_clothes))\n        logger.info(\"  query        | {:5d} | {:11d} | {:9d}\".format(num_query_pids, num_query_tracklets, num_query_clothes))\n        logger.info(\"  gallery      | {:5d} | {:11d} | {:9d}\".format(num_gallery_pids, num_gallery_tracklets, num_gallery_clothes))\n        logger.info(\"  ---------------------------------------------\")\n        logger.info(\"  total        | {:5d} | {:11d} | {:9d}\".format(num_total_pids, num_total_tracklets, num_total_clothes))\n        logger.info(\"  number of images per tracklet: {} ~ {}, average {:.1f}\".format(min_num, max_num, avg_num))\n        logger.info(\"  ---------------------------------------------\")\n\n        self.train = train\n        self.train_dense = train_dense\n        self.query = query\n        self.gallery = gallery\n\n        self.recombined_query = recombined_query\n        self.recombined_gallery = recombined_gallery\n        
self.query_vid2clip_index = query_vid2clip_index\n        self.gallery_vid2clip_index = gallery_vid2clip_index\n\n        self.num_train_pids = num_train_pids\n        self.num_train_clothes = num_train_clothes\n        self.pid2clothes = pid2clothes\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.root):\n            raise RuntimeError(\"'{}' is not available\".format(self.root))\n        if not osp.exists(self.train_path):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_path))\n        if not osp.exists(self.query_path):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_path))\n        if not osp.exists(self.gallery_path):\n            raise RuntimeError(\"'{}' is not available\".format(self.gallery_path))\n\n    def _clothes2label_test(self, query_path, gallery_path):\n        pid_container = set()\n        clothes_container = set()\n        with open(query_path, 'r') as f:\n            for line in f:\n                new_line = line.rstrip()\n                tracklet_path, pid, clothes_label = new_line.split()\n                clothes = '{}_{}'.format(pid, clothes_label)\n                pid_container.add(pid)\n                clothes_container.add(clothes)\n        with open(gallery_path, 'r') as f:\n            for line in f:\n                new_line = line.rstrip()\n                tracklet_path, pid, clothes_label = new_line.split()\n                clothes = '{}_{}'.format(pid, clothes_label)\n                pid_container.add(pid)\n                clothes_container.add(clothes)\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid:label for label, pid in enumerate(pid_container)}\n        clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}\n\n        return clothes2label\n\n    def _process_data(self, 
data_path, relabel=False, clothes2label=None):\n        tracklet_path_list = []\n        pid_container = set()\n        clothes_container = set()\n        with open(data_path, 'r') as f:\n            for line in f:\n                new_line = line.rstrip()\n                tracklet_path, pid, clothes_label = new_line.split()\n                tracklet_path_list.append((tracklet_path, pid, clothes_label))\n                clothes = '{}_{}'.format(pid, clothes_label)\n                pid_container.add(pid)\n                clothes_container.add(clothes)\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid:label for label, pid in enumerate(pid_container)}\n        if clothes2label is None:\n            clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}\n\n        num_tracklets = len(tracklet_path_list)\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        tracklets = []\n        num_imgs_per_tracklet = []\n        pid2clothes = np.zeros((num_pids, len(clothes2label)))\n\n        for tracklet_path, pid, clothes_label in tracklet_path_list:\n            img_paths = glob.glob(osp.join(self.root, tracklet_path, '*')) \n            img_paths.sort()\n\n            clothes = '{}_{}'.format(pid, clothes_label)\n            clothes_id = clothes2label[clothes]\n            pid2clothes[pid2label[pid], clothes_id] = 1\n            if relabel:\n                pid = pid2label[pid]\n            else:\n                pid = int(pid)\n            session = tracklet_path.split('/')[0]\n            cam = tracklet_path.split('_')[1]\n            if session == 'session3':\n                camid = int(cam) + 12\n            else:\n                camid = int(cam)\n\n            num_imgs_per_tracklet.append(len(img_paths))\n            tracklets.append((img_paths, pid, camid, clothes_id))\n\n        num_tracklets = len(tracklets)\n\n        
return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet, num_clothes, pid2clothes, clothes2label\n\n    def _densesampling_for_trainingset(self, dataset, sampling_step=64):\n        ''' Split all videos in training set into lots of clips for dense sampling.\n\n        Args:\n            dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id)\n            sampling_step (int): sampling step for dense sampling\n\n        Returns:\n            new_dataset (list): output dataset\n        '''\n        new_dataset = []\n        for (img_paths, pid, camid, clothes_id) in dataset:\n            if sampling_step != 0:\n                num_sampling = len(img_paths)//sampling_step\n                if num_sampling == 0:\n                    new_dataset.append((img_paths, pid, camid, clothes_id))\n                else:\n                    for idx in range(num_sampling):\n                        if idx == num_sampling - 1:\n                            new_dataset.append((img_paths[idx*sampling_step:], pid, camid, clothes_id))\n                        else:\n                            new_dataset.append((img_paths[idx*sampling_step : (idx+1)*sampling_step], pid, camid, clothes_id))\n            else:\n                new_dataset.append((img_paths, pid, camid, clothes_id))\n\n        return new_dataset\n\n    def _recombination_for_testset(self, dataset, seq_len=16, stride=4):\n        ''' Split all videos in test set into lots of equilong clips.\n\n        Args:\n            dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id)\n            seq_len (int): sequence length of each output clip\n            stride (int): temporal sampling stride\n\n        Returns:\n            new_dataset (list): output dataset with lots of equilong clips\n            vid2clip_index (list): a list contains the start and end clip index of each original video\n        '''\n        new_dataset = []\n        
vid2clip_index = np.zeros((len(dataset), 2), dtype=int)\n        for idx, (img_paths, pid, camid, clothes_id) in enumerate(dataset):\n            # start index\n            vid2clip_index[idx, 0] = len(new_dataset)\n            # process the sequence that can be divisible by seq_len*stride\n            for i in range(len(img_paths)//(seq_len*stride)):\n                for j in range(stride):\n                    begin_idx = i * (seq_len * stride) + j\n                    end_idx = (i + 1) * (seq_len * stride)\n                    clip_paths = img_paths[begin_idx : end_idx : stride]\n                    assert(len(clip_paths) == seq_len)\n                    new_dataset.append((clip_paths, pid, camid, clothes_id))\n            # process the remaining sequence that can't be divisible by seq_len*stride        \n            if len(img_paths)%(seq_len*stride) != 0:\n                # reducing stride\n                new_stride = (len(img_paths)%(seq_len*stride)) // seq_len\n                for i in range(new_stride):\n                    begin_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + i\n                    end_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + seq_len * new_stride\n                    clip_paths = img_paths[begin_idx : end_idx : new_stride]\n                    assert(len(clip_paths) == seq_len)\n                    new_dataset.append((clip_paths, pid, camid, clothes_id))\n                # process the remaining sequence that can't be divisible by seq_len\n                if len(img_paths) % seq_len != 0:\n                    clip_paths = img_paths[len(img_paths)//seq_len*seq_len:]\n                    # loop padding\n                    while len(clip_paths) < seq_len:\n                        for index in clip_paths:\n                            if len(clip_paths) >= seq_len:\n                                break\n                            clip_paths.append(index)\n                    assert(len(clip_paths) == 
seq_len)\n                    new_dataset.append((clip_paths, pid, camid, clothes_id))\n            # end index\n            vid2clip_index[idx, 1] = len(new_dataset)\n            assert((vid2clip_index[idx, 1]-vid2clip_index[idx, 0]) == math.ceil(len(img_paths)/seq_len))\n\n        return new_dataset, vid2clip_index.tolist()\n\n"
  },
  {
    "path": "data/datasets/deepchange.py",
    "content": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nfrom tools.utils import mkdir_if_missing, write_json, read_json\n     \n\nclass DeepChange(object):\n    \"\"\" DeepChange\n\n    Reference:\n        Xu et al. DeepChange: A Long-Term Person Re-Identification Benchmark. arXiv:2105.14685, 2021.\n\n    URL: https://github.com/PengBoXiangShang/deepchange\n    \"\"\"\n    dataset_dir = 'DeepChangeDataset'\n    def __init__(self, root='data', **kwargs):\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'train-set')\n        self.train_list = osp.join(self.dataset_dir, 'train-set-bbox.txt')\n        self.val_query_dir = osp.join(self.dataset_dir, 'val-set-query')\n        self.val_query_list = osp.join(self.dataset_dir, 'val-set-query-bbox.txt')\n        self.val_gallery_dir = osp.join(self.dataset_dir, 'val-set-gallery')\n        self.val_gallery_list = osp.join(self.dataset_dir, 'val-set-gallery-bbox.txt')\n        self.test_query_dir = osp.join(self.dataset_dir, 'test-set-query')\n        self.test_query_list = osp.join(self.dataset_dir, 'test-set-query-bbox.txt')\n        self.test_gallery_dir = osp.join(self.dataset_dir, 'test-set-gallery')\n        self.test_gallery_list = osp.join(self.dataset_dir, 'test-set-gallery-bbox.txt')\n        self._check_before_run()\n\n        train_names = self._get_names(self.train_list)\n        val_query_names = self._get_names(self.val_query_list)\n        val_gallery_names = self._get_names(self.val_gallery_list)\n        test_query_names = self._get_names(self.test_query_list)\n        test_gallery_names = self._get_names(self.test_gallery_list)\n\n        pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(train_names)\n        train, num_train_pids, num_train_clothes = self._process_dir(self.train_dir, 
train_names, clothes2label, pid2label=pid2label)\n\n        pid2label, clothes2label = self.get_pid2label_and_clothes2label(val_query_names, val_gallery_names)\n        val_query, num_val_query_pids, num_val_query_clothes  = self._process_dir(self.val_query_dir, val_query_names, clothes2label)\n        val_gallery, num_val_gallery_pids, num_val_gallery_clothes = self._process_dir(self.val_gallery_dir, val_gallery_names, clothes2label)\n        num_val_pids = len(pid2label)\n        num_val_clothes = len(clothes2label)\n\n        pid2label, clothes2label = self.get_pid2label_and_clothes2label(test_query_names, test_gallery_names)\n        test_query, num_test_query_pids, num_test_query_clothes = self._process_dir(self.test_query_dir, test_query_names, clothes2label)\n        test_gallery, num_test_gallery_pids, num_test_gallery_clothes = self._process_dir(self.test_gallery_dir, test_gallery_names, clothes2label)\n        num_test_pids = len(pid2label)\n        num_test_clothes = len(clothes2label)\n\n        num_total_pids = num_train_pids + num_val_pids + num_test_pids\n        num_total_clothes = num_train_clothes + num_val_clothes + num_test_clothes\n        num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + len(test_gallery)\n\n        logger = logging.getLogger('reid.dataset')\n        logger.info(\"=> DeepChange loaded\")\n        logger.info(\"Dataset statistics:\")\n        logger.info(\"  --------------------------------------------\")\n        logger.info(\"  subset        | # ids | # images | # clothes\")\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  train         | {:5d} | {:8d} | {:9d} \".format(num_train_pids, len(train), num_train_clothes))\n        logger.info(\"  query(val)    | {:5d} | {:8d} | {:9d} \".format(num_val_query_pids, len(val_query), num_val_query_clothes))\n        logger.info(\"  gallery(val)  | {:5d} | {:8d} | {:9d} \".format(num_val_gallery_pids, 
len(val_gallery), num_val_gallery_clothes))\n        logger.info(\"  query         | {:5d} | {:8d} | {:9d} \".format(num_test_query_pids, len(test_query), num_test_query_clothes))\n        logger.info(\"  gallery       | {:5d} | {:8d} | {:9d} \".format(num_test_gallery_pids, len(test_gallery), num_test_gallery_clothes))\n        logger.info(\"  --------------------------------------------\")\n        logger.info(\"  total         | {:5d} | {:8d} | {:9d} \".format(num_total_pids, num_total_imgs, num_total_clothes))\n        logger.info(\"  --------------------------------------------\")\n\n        self.train = train\n        self.val_query = val_query\n        self.val_gallery = val_gallery\n        self.query = test_query\n        self.gallery = test_gallery\n\n        self.num_train_pids = num_train_pids\n        self.num_train_clothes = num_train_clothes\n        self.pid2clothes = pid2clothes\n\n    def _get_names(self, fpath):\n        names = []\n        with open(fpath, 'r') as f:\n            for line in f:\n                new_line = line.rstrip()\n                names.append(new_line)\n        return names\n\n    def get_pid2label_and_clothes2label(self, img_names1, img_names2=None):\n        if img_names2 is not None:\n            img_names = img_names1 + img_names2\n        else:\n            img_names = img_names1\n\n        pid_container = set()\n        clothes_container = set()\n        for img_name in img_names:\n            names = img_name.split('.')[0].split('_')\n            clothes = names[0] + names[2]\n            pid = int(names[0][1:])\n            pid_container.add(pid)\n            clothes_container.add(clothes)\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n        clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}\n\n        if img_names2 is not None:\n            return 
pid2label, clothes2label\n\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n        pid2clothes = np.zeros((num_pids, num_clothes))\n        for img_name in img_names:\n            names = img_name.split('.')[0].split('_')\n            clothes = names[0] + names[2]\n            pid = int(names[0][1:])\n            pid = pid2label[pid]\n            clothes_id = clothes2label[clothes]\n            pid2clothes[pid, clothes_id] = 1\n\n        return pid2label, clothes2label, pid2clothes\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.val_query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.val_query_dir))\n        if not osp.exists(self.val_gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.val_gallery_dir))\n        if not osp.exists(self.test_query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.test_query_dir))\n        if not osp.exists(self.test_gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.test_gallery_dir))\n\n    def _process_dir(self, home_dir, img_names, clothes2label, pid2label=None):\n        dataset = []\n        pid_container = set()\n        clothes_container = set()\n        for img_name in img_names:\n            img_path = osp.join(home_dir, img_name.split(',')[0])\n            names = img_name.split('.')[0].split('_')\n            tracklet_id = int(img_name.split(',')[1])\n            clothes = names[0] + names[2]\n            clothes_id = clothes2label[clothes]\n            clothes_container.add(clothes_id)\n            pid = 
int(names[0][1:])\n            pid_container.add(pid)\n            camid = int(names[1][1:])\n            if pid2label is not None:\n                pid = pid2label[pid]\n            # on DeepChange, we allow the true matches coming from the same camera \n            # but different tracklets as query following the original paper.\n            # So we use tracklet_id to replace camid for each sample.\n            dataset.append((img_path, pid, tracklet_id, clothes_id))\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        return dataset, num_pids, num_clothes"
  },
  {
    "path": "data/datasets/last.py",
    "content": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nfrom tools.utils import mkdir_if_missing, write_json, read_json\n\n\nclass LaST(object):\n    \"\"\" LaST\n\n    Reference:\n        Shu et al. Large-Scale Spatio-Temporal Person Re-identification: Algorithm and Benchmark. arXiv:2105.15076, 2021.\n\n    URL: https://github.com/shuxjweb/last\n\n    Note that LaST does not provide the clothes label for val and test set.\n    \"\"\"\n    dataset_dir = \"last\"\n    def __init__(self, root='data', **kwargs):\n        super(LaST, self).__init__()\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'train')\n        self.val_query_dir = osp.join(self.dataset_dir, 'val', 'query')\n        self.val_gallery_dir = osp.join(self.dataset_dir, 'val', 'gallery')\n        self.test_query_dir = osp.join(self.dataset_dir, 'test', 'query')\n        self.test_gallery_dir = osp.join(self.dataset_dir, 'test', 'gallery')\n        self._check_before_run()\n\n        pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(self.train_dir)\n\n        train, num_train_pids = self._process_dir(self.train_dir, pid2label=pid2label, clothes2label=clothes2label, relabel=True)\n        val_query, num_val_query_pids = self._process_dir(self.val_query_dir, relabel=False)\n        val_gallery, num_val_gallery_pids = self._process_dir(self.val_gallery_dir, relabel=False, recam=len(val_query))\n        test_query, num_test_query_pids = self._process_dir(self.test_query_dir, relabel=False)\n        test_gallery, num_test_gallery_pids = self._process_dir(self.test_gallery_dir, relabel=False, recam=len(test_query))\n\n        num_total_pids = num_train_pids+num_val_gallery_pids+num_test_gallery_pids\n        num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + 
len(test_gallery)\n\n        logger = logging.getLogger('reid.dataset')\n        logger.info(\"=> LaST loaded\")\n        logger.info(\"Dataset statistics:\")\n        logger.info(\"  --------------------------------------------\")\n        logger.info(\"  subset        | # ids | # images | # clothes\")\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  train         | {:5d} | {:8d} | {:9d}\".format(num_train_pids, len(train), len(clothes2label)))\n        logger.info(\"  query(val)    | {:5d} | {:8d} |\".format(num_val_query_pids, len(val_query)))\n        logger.info(\"  gallery(val)  | {:5d} | {:8d} |\".format(num_val_gallery_pids, len(val_gallery)))\n        logger.info(\"  query         | {:5d} | {:8d} |\".format(num_test_query_pids, len(test_query)))\n        logger.info(\"  gallery       | {:5d} | {:8d} |\".format(num_test_gallery_pids, len(test_gallery)))\n        logger.info(\"  --------------------------------------------\")\n        logger.info(\"  total         | {:5d} | {:8d} | \".format(num_total_pids, num_total_imgs))\n        logger.info(\"  --------------------------------------------\")\n\n        self.train = train\n        self.val_query = val_query\n        self.val_gallery = val_gallery\n        self.query = test_query\n        self.gallery = test_gallery\n\n        self.num_train_pids = num_train_pids\n        self.num_train_clothes = len(clothes2label)\n        self.pid2clothes = pid2clothes\n\n    def get_pid2label_and_clothes2label(self, dir_path):\n        img_paths = glob.glob(osp.join(dir_path, '*/*.jpg'))            # [103367,]\n        img_paths.sort()\n\n        pid_container = set()\n        clothes_container = set()\n        for img_path in img_paths:\n            names = osp.basename(img_path).split('.')[0].split('_')\n            clothes = names[0] + '_' + names[-1]\n            pid = int(names[0])\n            pid_container.add(pid)\n            clothes_container.add(clothes)\n        
pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n        clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)}\n\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        pid2clothes = np.zeros((num_pids, num_clothes))\n        for img_path in img_paths:\n            names = osp.basename(img_path).split('.')[0].split('_')\n            clothes = names[0] + '_' + names[-1]\n            pid = int(names[0])\n            pid = pid2label[pid]\n            clothes_id = clothes2label[clothes]\n            pid2clothes[pid, clothes_id] = 1\n\n        return pid2label, clothes2label, pid2clothes\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.val_query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.val_query_dir))\n        if not osp.exists(self.val_gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.val_gallery_dir))\n        if not osp.exists(self.test_query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.test_query_dir))\n        if not osp.exists(self.test_gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.test_gallery_dir))\n\n    def _process_dir(self, dir_path, pid2label=None, clothes2label=None, relabel=False, recam=0):\n        if 'query' in dir_path:\n            img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n        else:\n            img_paths = glob.glob(osp.join(dir_path, '*/*.jpg'))\n        
img_paths.sort()\n        \n        dataset = []\n        pid_container = set()\n        for ii, img_path in enumerate(img_paths):\n            names = osp.basename(img_path).split('.')[0].split('_')\n            clothes = names[0] + '_' + names[-1]\n            pid = int(names[0])\n            pid_container.add(pid)\n            camid = int(recam + ii)\n            if relabel and pid2label is not None:\n                pid = pid2label[pid]\n            if relabel and clothes2label is not None:\n                clothes_id = clothes2label[clothes]\n            else:\n                clothes_id = pid\n            dataset.append((img_path, pid, camid, clothes_id))\n        num_pids = len(pid_container)\n\n        return dataset, num_pids"
  },
  {
    "path": "data/datasets/ltcc.py",
    "content": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nfrom tools.utils import mkdir_if_missing, write_json, read_json\n\n\nclass LTCC(object):\n    \"\"\" LTCC\n\n    Reference:\n        Qian et al. Long-Term Cloth-Changing Person Re-identification. arXiv:2005.12633, 2020.\n\n    URL: https://naiq.github.io/LTCC_Perosn_ReID.html#\n    \"\"\"\n    dataset_dir = 'LTCC_ReID'\n    def __init__(self, root='data', **kwargs):\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'test')\n        self._check_before_run()\n\n        train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \\\n            self._process_dir_train(self.train_dir)\n        query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = \\\n            self._process_dir_test(self.query_dir, self.gallery_dir)\n        num_total_pids = num_train_pids + num_test_pids\n        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs\n        num_test_imgs = num_query_imgs + num_gallery_imgs \n        num_total_clothes = num_train_clothes + num_test_clothes\n\n        logger = logging.getLogger('reid.dataset')\n        logger.info(\"=> LTCC loaded\")\n        logger.info(\"Dataset statistics:\")\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  subset   | # ids | # images | # clothes\")\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  train    | {:5d} | {:8d} | {:9d}\".format(num_train_pids, num_train_imgs, num_train_clothes))\n        logger.info(\"  test     | {:5d} | {:8d} | {:9d}\".format(num_test_pids, num_test_imgs, num_test_clothes))\n        
logger.info(\"  query    | {:5d} | {:8d} |\".format(num_test_pids, num_query_imgs))\n        logger.info(\"  gallery  | {:5d} | {:8d} |\".format(num_test_pids, num_gallery_imgs))\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  total    | {:5d} | {:8d} | {:9d}\".format(num_total_pids, num_total_imgs, num_total_clothes))\n        logger.info(\"  ----------------------------------------\")\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids = num_train_pids\n        self.num_train_clothes = num_train_clothes\n        self.pid2clothes = pid2clothes\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_dir))\n        if not osp.exists(self.gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.gallery_dir))\n\n    def _process_dir_train(self, dir_path):\n        img_paths = glob.glob(osp.join(dir_path, '*.png'))\n        img_paths.sort()\n        pattern1 = re.compile(r'(\\d+)_(\\d+)_c(\\d+)')\n        pattern2 = re.compile(r'(\\w+)_c')\n\n        pid_container = set()\n        clothes_container = set()\n        for img_path in img_paths:\n            pid, _, _ = map(int, pattern1.search(img_path).groups())\n            clothes_id = pattern2.search(img_path).group(1)\n            pid_container.add(pid)\n            clothes_container.add(clothes_id)\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid:label for label, pid in 
enumerate(pid_container)}\n        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}\n\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        dataset = []\n        pid2clothes = np.zeros((num_pids, num_clothes))\n        for img_path in img_paths:\n            pid, _, camid = map(int, pattern1.search(img_path).groups())\n            clothes = pattern2.search(img_path).group(1)\n            camid -= 1 # index starts from 0\n            pid = pid2label[pid]\n            clothes_id = clothes2label[clothes]\n            dataset.append((img_path, pid, camid, clothes_id))\n            pid2clothes[pid, clothes_id] = 1\n        \n        num_imgs = len(dataset)\n\n        return dataset, num_pids, num_imgs, num_clothes, pid2clothes\n\n    def _process_dir_test(self, query_path, gallery_path):\n        query_img_paths = glob.glob(osp.join(query_path, '*.png'))\n        gallery_img_paths = glob.glob(osp.join(gallery_path, '*.png'))\n        query_img_paths.sort()\n        gallery_img_paths.sort()\n        pattern1 = re.compile(r'(\\d+)_(\\d+)_c(\\d+)')\n        pattern2 = re.compile(r'(\\w+)_c')\n\n        pid_container = set()\n        clothes_container = set()\n        for img_path in query_img_paths:\n            pid, _, _ = map(int, pattern1.search(img_path).groups())\n            clothes_id = pattern2.search(img_path).group(1)\n            pid_container.add(pid)\n            clothes_container.add(clothes_id)\n        for img_path in gallery_img_paths:\n            pid, _, _ = map(int, pattern1.search(img_path).groups())\n            clothes_id = pattern2.search(img_path).group(1)\n            pid_container.add(pid)\n            clothes_container.add(clothes_id)\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid:label for label, pid in enumerate(pid_container)}\n        clothes2label = {clothes_id:label for label, 
clothes_id in enumerate(clothes_container)}\n\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        query_dataset = []\n        gallery_dataset = []\n        for img_path in query_img_paths:\n            pid, _, camid = map(int, pattern1.search(img_path).groups())\n            clothes_id = pattern2.search(img_path).group(1)\n            camid -= 1 # index starts from 0\n            clothes_id = clothes2label[clothes_id]\n            query_dataset.append((img_path, pid, camid, clothes_id))\n\n        for img_path in gallery_img_paths:\n            pid, _, camid = map(int, pattern1.search(img_path).groups())\n            clothes_id = pattern2.search(img_path).group(1)\n            camid -= 1 # index starts from 0\n            clothes_id = clothes2label[clothes_id]\n            gallery_dataset.append((img_path, pid, camid, clothes_id))\n        \n        num_imgs_query = len(query_dataset)\n        num_imgs_gallery = len(gallery_dataset)\n\n        return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes\n\n"
  },
  {
    "path": "data/datasets/prcc.py",
    "content": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nfrom tools.utils import mkdir_if_missing, write_json, read_json\n\n\nclass PRCC(object):\n    \"\"\" PRCC\n\n    Reference:\n        Yang et al. Person Re-identification by Contour Sketch under Moderate Clothing Change. TPAMI, 2019.\n\n    URL: https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view\n    \"\"\"\n    dataset_dir = 'prcc'\n    def __init__(self, root='data', **kwargs):\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'rgb/train')\n        self.val_dir = osp.join(self.dataset_dir, 'rgb/val')\n        self.test_dir = osp.join(self.dataset_dir, 'rgb/test')\n        self._check_before_run()\n\n        train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \\\n            self._process_dir_train(self.train_dir)\n        val, num_val_pids, num_val_imgs, num_val_clothes, _ = \\\n            self._process_dir_train(self.val_dir)\n\n        query_same, query_diff, gallery, num_test_pids, \\\n            num_query_imgs_same, num_query_imgs_diff, num_gallery_imgs, \\\n            num_test_clothes, gallery_idx = self._process_dir_test(self.test_dir)\n\n        num_total_pids = num_train_pids + num_test_pids\n        num_test_imgs = num_query_imgs_same + num_query_imgs_diff + num_gallery_imgs\n        num_total_imgs = num_train_imgs + num_val_imgs + num_test_imgs\n        num_total_clothes = num_train_clothes + num_test_clothes\n\n        logger = logging.getLogger('reid.dataset')\n        logger.info(\"=> PRCC loaded\")\n        logger.info(\"Dataset statistics:\")\n        logger.info(\"  --------------------------------------------\")\n        logger.info(\"  subset      | # ids | # images | # clothes\")\n        logger.info(\"  --------------------------------------------\")\n        
logger.info(\"  train       | {:5d} | {:8d} | {:9d}\".format(num_train_pids, num_train_imgs, num_train_clothes))\n        logger.info(\"  val         | {:5d} | {:8d} | {:9d}\".format(num_val_pids, num_val_imgs, num_val_clothes))\n        logger.info(\"  test        | {:5d} | {:8d} | {:9d}\".format(num_test_pids, num_test_imgs, num_test_clothes))\n        logger.info(\"  query(same) | {:5d} | {:8d} |\".format(num_test_pids, num_query_imgs_same))\n        logger.info(\"  query(diff) | {:5d} | {:8d} |\".format(num_test_pids, num_query_imgs_diff))\n        logger.info(\"  gallery     | {:5d} | {:8d} |\".format(num_test_pids, num_gallery_imgs))\n        logger.info(\"  --------------------------------------------\")\n        logger.info(\"  total       | {:5d} | {:8d} | {:9d}\".format(num_total_pids, num_total_imgs, num_total_clothes))\n        logger.info(\"  --------------------------------------------\")\n\n        self.train = train\n        self.val = val\n        self.query_same = query_same\n        self.query_diff = query_diff\n        self.gallery = gallery\n\n        self.num_train_pids = num_train_pids\n        self.num_train_clothes = num_train_clothes\n        self.pid2clothes = pid2clothes\n        self.gallery_idx = gallery_idx\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.val_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.val_dir))\n        if not osp.exists(self.test_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.test_dir))\n\n    def _process_dir_train(self, dir_path):\n        pdirs = glob.glob(osp.join(dir_path, '*'))\n        pdirs.sort()\n\n      
  pid_container = set()\n        clothes_container = set()\n        for pdir in pdirs:\n            pid = int(osp.basename(pdir))\n            pid_container.add(pid)\n            img_dirs = glob.glob(osp.join(pdir, '*.jpg'))\n            for img_dir in img_dirs:\n                cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C'\n                if cam in ['A', 'B']:\n                    clothes_container.add(osp.basename(pdir))\n                else:\n                    clothes_container.add(osp.basename(pdir)+osp.basename(img_dir)[0])\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid:label for label, pid in enumerate(pid_container)}\n        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}\n        cam2label = {'A': 0, 'B': 1, 'C': 2}\n\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        dataset = []\n        pid2clothes = np.zeros((num_pids, num_clothes))\n        for pdir in pdirs:\n            pid = int(osp.basename(pdir))\n            img_dirs = glob.glob(osp.join(pdir, '*.jpg'))\n            for img_dir in img_dirs:\n                cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C'\n                label = pid2label[pid]\n                camid = cam2label[cam]\n                if cam in ['A', 'B']:\n                    clothes_id = clothes2label[osp.basename(pdir)]\n                else:\n                    clothes_id = clothes2label[osp.basename(pdir)+osp.basename(img_dir)[0]]\n                dataset.append((img_dir, label, camid, clothes_id))\n                pid2clothes[label, clothes_id] = 1            \n        \n        num_imgs = len(dataset)\n\n        return dataset, num_pids, num_imgs, num_clothes, pid2clothes\n\n    def _process_dir_test(self, test_path):\n        pdirs = glob.glob(osp.join(test_path, '*'))\n        pdirs.sort()\n\n        pid_container = set()\n        for pdir in 
glob.glob(osp.join(test_path, 'A', '*')):\n            pid = int(osp.basename(pdir))\n            pid_container.add(pid)\n        pid_container = sorted(pid_container)\n        pid2label = {pid:label for label, pid in enumerate(pid_container)}\n        cam2label = {'A': 0, 'B': 1, 'C': 2}\n\n        num_pids = len(pid_container)\n        num_clothes = num_pids * 2\n\n        query_dataset_same_clothes = []\n        query_dataset_diff_clothes = []\n        gallery_dataset = []\n        for cam in ['A', 'B', 'C']:\n            pdirs = glob.glob(osp.join(test_path, cam, '*'))\n            for pdir in pdirs:\n                pid = int(osp.basename(pdir))\n                img_dirs = glob.glob(osp.join(pdir, '*.jpg'))\n                for img_dir in img_dirs:\n                    # pid = pid2label[pid]\n                    camid = cam2label[cam]\n                    if cam == 'A':\n                        clothes_id = pid2label[pid] * 2\n                        gallery_dataset.append((img_dir, pid, camid, clothes_id))\n                    elif cam == 'B':\n                        clothes_id = pid2label[pid] * 2\n                        query_dataset_same_clothes.append((img_dir, pid, camid, clothes_id))\n                    else:\n                        clothes_id = pid2label[pid] * 2 + 1\n                        query_dataset_diff_clothes.append((img_dir, pid, camid, clothes_id))\n\n        pid2imgidx = {}\n        for idx, (img_dir, pid, camid, clothes_id) in enumerate(gallery_dataset):\n            if pid not in pid2imgidx:\n                pid2imgidx[pid] = []\n            pid2imgidx[pid].append(idx)\n\n        # get 10 gallery index to perform single-shot test\n        gallery_idx = {}\n        random.seed(3)\n        for idx in range(0, 10):\n            gallery_idx[idx] = []\n            for pid in pid2imgidx:\n                gallery_idx[idx].append(random.choice(pid2imgidx[pid]))\n                 \n        num_imgs_query_same = 
len(query_dataset_same_clothes)\n        num_imgs_query_diff = len(query_dataset_diff_clothes)\n        num_imgs_gallery = len(gallery_dataset)\n\n        return query_dataset_same_clothes, query_dataset_diff_clothes, gallery_dataset, \\\n               num_pids, num_imgs_query_same, num_imgs_query_diff, num_imgs_gallery, \\\n               num_clothes, gallery_idx\n"
  },
  {
    "path": "data/datasets/vcclothes.py",
    "content": "import os\nimport re\nimport glob\nimport h5py\nimport random\nimport math\nimport logging\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nfrom tools.utils import mkdir_if_missing, write_json, read_json\n\n\nclass VCClothes(object):\n    \"\"\" VC-Clothes\n\n    Reference:\n        Wang et al. When Person Re-identification Meets Changing Clothes. In CVPR Workshop, 2020.\n\n    URL: https://wanfb.github.io/dataset.html\n    \"\"\"\n    dataset_dir = 'VC-Clothes'\n    def __init__(self, root='data', mode='all', **kwargs):\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')\n        # 'all' for all cameras; 'sc' for cam2&3; 'cc' for cam3&4\n        self.mode = mode \n        self._check_before_run()\n\n        train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = self._process_dir_train()\n        query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = self._process_dir_test()\n        num_total_pids = num_train_pids + num_test_pids\n        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs\n        num_test_imgs = num_query_imgs + num_gallery_imgs \n        num_total_clothes = num_train_clothes + num_test_clothes\n\n        logger = logging.getLogger('reid.dataset')\n        logger.info(\"=> VC-Clothes loaded\")\n        logger.info(\"Dataset statistics:\")\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  subset   | # ids | # images | # clothes\")\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  train    | {:5d} | {:8d} | {:9d}\".format(num_train_pids, num_train_imgs, num_train_clothes))\n        logger.info(\"  test     | {:5d} | {:8d} | {:9d}\".format(num_test_pids, 
num_test_imgs, num_test_clothes))\n        logger.info(\"  query    | {:5d} | {:8d} |\".format(num_test_pids, num_query_imgs))\n        logger.info(\"  gallery  | {:5d} | {:8d} |\".format(num_test_pids, num_gallery_imgs))\n        logger.info(\"  ----------------------------------------\")\n        logger.info(\"  total    | {:5d} | {:8d} | {:9d}\".format(num_total_pids, num_total_imgs, num_total_clothes))\n        logger.info(\"  ----------------------------------------\")\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids = num_train_pids\n        self.num_train_clothes = num_train_clothes\n        self.pid2clothes = pid2clothes\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_dir))\n        if not osp.exists(self.gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.gallery_dir))\n\n    def _process_dir_train(self):\n        img_paths = glob.glob(osp.join(self.train_dir, '*.jpg'))\n        img_paths.sort()\n        pattern = re.compile(r'(\\d+)-(\\d+)-(\\d+)-(\\d+)')\n\n        pid_container = set()\n        clothes_container = set()\n        for img_path in img_paths:\n            pid, camid, clothes, _ = pattern.search(img_path).groups()\n            clothes_id = pid + clothes\n            pid, camid = int(pid), int(camid)\n            pid_container.add(pid)\n            clothes_container.add(clothes_id)\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = 
{pid:label for label, pid in enumerate(pid_container)}\n        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}\n\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        dataset = []\n        pid2clothes = np.zeros((num_pids, num_clothes))\n        for img_path in img_paths:\n            pid, camid, clothes, _ = pattern.search(img_path).groups()\n            clothes_id = pid + clothes\n            pid, camid = int(pid), int(camid)\n            camid -= 1 # index starts from 0\n            pid = pid2label[pid]\n            clothes_id = clothes2label[clothes_id]\n            dataset.append((img_path, pid, camid, clothes_id))\n            pid2clothes[pid, clothes_id] = 1\n        \n        num_imgs = len(dataset)\n\n        return dataset, num_pids, num_imgs, num_clothes, pid2clothes\n\n    def _process_dir_test(self):\n        query_img_paths = glob.glob(osp.join(self.query_dir, '*.jpg'))\n        gallery_img_paths = glob.glob(osp.join(self.gallery_dir, '*.jpg'))\n        query_img_paths.sort()\n        gallery_img_paths.sort()\n        pattern = re.compile(r'(\\d+)-(\\d+)-(\\d+)-(\\d+)')\n\n        pid_container = set()\n        clothes_container = set()\n        for img_path in query_img_paths:\n            pid, camid, clothes, _ = pattern.search(img_path).groups()\n            clothes_id = pid + clothes\n            pid, camid = int(pid), int(camid)\n            if self.mode == 'sc' and camid not in [2, 3]:\n                continue\n            if self.mode == 'cc' and camid not in [3, 4]:\n                continue\n            pid_container.add(pid)\n            clothes_container.add(clothes_id)\n        for img_path in gallery_img_paths:\n            pid, camid, clothes, _ = pattern.search(img_path).groups()\n            clothes_id = pid + clothes\n            pid, camid = int(pid), int(camid)\n            if self.mode == 'sc' and camid not in [2, 3]:\n                continue\n    
        if self.mode == 'cc' and camid not in [3, 4]:\n                continue\n            pid_container.add(pid)\n            clothes_container.add(clothes_id)\n        pid_container = sorted(pid_container)\n        clothes_container = sorted(clothes_container)\n        pid2label = {pid:label for label, pid in enumerate(pid_container)}\n        clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)}\n\n        num_pids = len(pid_container)\n        num_clothes = len(clothes_container)\n\n        query_dataset = []\n        gallery_dataset = []\n        for img_path in query_img_paths:\n            pid, camid, clothes, _ = pattern.search(img_path).groups()\n            clothes_id = pid + clothes\n            pid, camid = int(pid), int(camid)\n            if self.mode == 'sc' and camid not in [2, 3]:\n                continue\n            if self.mode == 'cc' and camid not in [3, 4]:\n                continue\n            camid -= 1 # index starts from 0\n            clothes_id = clothes2label[clothes_id]\n            query_dataset.append((img_path, pid, camid, clothes_id))\n\n        for img_path in gallery_img_paths:\n            pid, camid, clothes, _ = pattern.search(img_path).groups()\n            clothes_id = pid + clothes\n            pid, camid = int(pid), int(camid)\n            if self.mode == 'sc' and camid not in [2, 3]:\n                continue\n            if self.mode == 'cc' and camid not in [3, 4]:\n                continue\n            camid -= 1 # index starts from 0\n            clothes_id = clothes2label[clothes_id]\n            gallery_dataset.append((img_path, pid, camid, clothes_id))\n        \n        num_imgs_query = len(query_dataset)\n        num_imgs_gallery = len(gallery_dataset)\n\n        return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes\n\n\ndef VCClothesSameClothes(root='data', **kwargs):\n    return VCClothes(root=root, mode='sc')\n\n\ndef 
VCClothesClothesChanging(root='data', **kwargs):\n    return VCClothes(root=root, mode='cc')\n"
  },
  {
    "path": "data/img_transforms.py",
    "content": "from torchvision.transforms import *\nfrom PIL import Image\nimport random\nimport math\n\n\nclass ResizeWithEqualScale(object):\n    \"\"\"\n    Resize an image with equal scale as the original image.\n\n    Args:\n        height (int): resized height.\n        width (int): resized width.\n        interpolation: interpolation manner.\n        fill_color (tuple): color for padding.\n    \"\"\"\n    def __init__(self, height, width, interpolation=Image.BILINEAR, fill_color=(0,0,0)):\n        self.height = height\n        self.width = width\n        self.interpolation = interpolation\n        self.fill_color = fill_color\n\n    def __call__(self, img):\n        width, height = img.size\n        if self.height / self.width >= height / width:\n            height = int(self.width * (height / width))\n            width = self.width\n        else:\n            width = int(self.height * (width / height))\n            height = self.height \n\n        resized_img = img.resize((width, height), self.interpolation)\n        new_img = Image.new('RGB', (self.width, self.height), self.fill_color)\n        new_img.paste(resized_img, (int((self.width - width) / 2), int((self.height - height) / 2)))\n\n        return new_img\n\n\nclass RandomCroping(object):\n    \"\"\"\n    With a probability, first increase image size to (1 + 1/8), and then perform random crop.\n\n    Args:\n        p (float): probability of performing this transformation. 
Default: 0.5.\n    \"\"\"\n    def __init__(self, p=0.5, interpolation=Image.BILINEAR):\n        self.p = p\n        self.interpolation = interpolation\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be cropped.\n\n        Returns:\n            PIL Image: Cropped image.\n        \"\"\"\n        width, height = img.size\n        if random.uniform(0, 1) >= self.p:\n            return img\n        \n        new_width, new_height = int(round(width * 1.125)), int(round(height * 1.125))\n        resized_img = img.resize((new_width, new_height), self.interpolation)\n        x_maxrange = new_width - width\n        y_maxrange = new_height - height\n        x1 = int(round(random.uniform(0, x_maxrange)))\n        y1 = int(round(random.uniform(0, y_maxrange)))\n        croped_img = resized_img.crop((x1, y1, x1 + width, y1 + height))\n\n        return croped_img\n\n\nclass RandomErasing(object):\n    \"\"\" \n    Randomly selects a rectangle region in an image and erases its pixels.\n\n    Reference:\n        Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017.\n\n    Args:\n        probability: The probability that the Random Erasing operation will be performed.\n        sl: Minimum proportion of erased area against input image.\n        sh: Maximum proportion of erased area against input image.\n        r1: Minimum aspect ratio of erased area.\n        mean: Erasing value. 
\n    \"\"\"\n    \n    def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):\n        self.probability = probability\n        self.mean = mean\n        self.sl = sl\n        self.sh = sh\n        self.r1 = r1\n       \n    def __call__(self, img):\n\n        if random.uniform(0, 1) >= self.probability:\n            return img\n\n        for attempt in range(100):\n            area = img.size()[1] * img.size()[2]\n       \n            target_area = random.uniform(self.sl, self.sh) * area\n            aspect_ratio = random.uniform(self.r1, 1/self.r1)\n\n            h = int(round(math.sqrt(target_area * aspect_ratio)))\n            w = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if w < img.size()[2] and h < img.size()[1]:\n                x1 = random.randint(0, img.size()[1] - h)\n                y1 = random.randint(0, img.size()[2] - w)\n                if img.size()[0] == 3:\n                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]\n                    img[1, x1:x1+h, y1:y1+w] = self.mean[1]\n                    img[2, x1:x1+h, y1:y1+w] = self.mean[2]\n                else:\n                    img[0, x1:x1+h, y1:y1+w] = self.mean[0]\n                return img\n\n        return img"
  },
  {
    "path": "data/samplers.py",
    "content": "import copy\nimport math\nimport random\nimport numpy as np\nfrom torch import distributed as dist\nfrom collections import defaultdict\nfrom torch.utils.data.sampler import Sampler\n\n\nclass RandomIdentitySampler(Sampler):\n    \"\"\"\n    Randomly sample N identities, then for each identity,\n    randomly sample K instances, therefore batch size is N*K.\n\n    Args:\n        data_source (Dataset): dataset to sample from.\n        num_instances (int): number of instances per identity.\n    \"\"\"\n    def __init__(self, data_source, num_instances=4):\n        self.data_source = data_source\n        self.num_instances = num_instances\n        self.index_dic = defaultdict(list)\n        for index, (_, pid, _, _) in enumerate(data_source):\n            self.index_dic[pid].append(index)\n        self.pids = list(self.index_dic.keys())\n        self.num_identities = len(self.pids)\n\n        # compute number of examples in an epoch\n        self.length = 0\n        for pid in self.pids:\n            idxs = self.index_dic[pid]\n            num = len(idxs)\n            if num < self.num_instances:\n                num = self.num_instances\n            self.length += num - num % self.num_instances\n\n    def __iter__(self):\n        list_container = []\n\n        for pid in self.pids:\n            idxs = copy.deepcopy(self.index_dic[pid])\n            if len(idxs) < self.num_instances:\n                idxs = np.random.choice(idxs, size=self.num_instances, replace=True)\n            random.shuffle(idxs)\n            batch_idxs = []\n            for idx in idxs:\n                batch_idxs.append(idx)\n                if len(batch_idxs) == self.num_instances:\n                    list_container.append(batch_idxs)\n                    batch_idxs = []\n\n        random.shuffle(list_container)\n\n        ret = []\n        for batch_idxs in list_container:\n            ret.extend(batch_idxs)\n\n        return iter(ret)\n\n    def __len__(self):\n        return 
self.length\n\n\nclass DistributedRandomIdentitySampler(Sampler):\n    \"\"\"\n    Randomly sample N identities, then for each identity,\n    randomly sample K instances, therefore batch size is N*K.\n\n    Args:\n    - data_source (Dataset): dataset to sample from.\n    - num_instances (int): number of instances per identity.\n    - num_replicas (int, optional): Number of processes participating in\n        distributed training. By default, :attr:`world_size` is retrieved from the\n        current distributed group.\n    - rank (int, optional): Rank of the current process within :attr:`num_replicas`.\n        By default, :attr:`rank` is retrieved from the current distributed group.\n    - seed (int, optional): random seed used to shuffle the sampler. \n        This number should be identical across all\n        processes in the distributed group. Default: ``0``.\n    \"\"\"\n    def __init__(self, data_source, num_instances=4, \n                 num_replicas=None, rank=None, seed=0):\n        if num_replicas is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            num_replicas = dist.get_world_size()\n        if rank is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            rank = dist.get_rank()\n        if rank >= num_replicas or rank < 0:\n            raise ValueError(\n                \"Invalid rank {}, rank should be in the interval\"\n                \" [0, {}]\".format(rank, num_replicas - 1))\n        self.num_replicas = num_replicas\n        self.rank = rank\n        self.seed = seed\n        self.epoch = 0\n\n        self.data_source = data_source\n        self.num_instances = num_instances\n        self.index_dic = defaultdict(list)\n        for index, (_, pid, _, _) in enumerate(data_source):\n            self.index_dic[pid].append(index)\n        self.pids = 
list(self.index_dic.keys())\n        self.num_identities = len(self.pids)\n\n        # compute number of examples in an epoch\n        self.length = 0\n        for pid in self.pids:\n            idxs = self.index_dic[pid]\n            num = len(idxs)\n            if num < self.num_instances:\n                num = self.num_instances\n            self.length += num - num % self.num_instances\n        assert self.length % self.num_instances == 0\n\n        if self.length // self.num_instances % self.num_replicas != 0: \n            self.num_samples = math.ceil((self.length // self.num_instances - self.num_replicas) / self.num_replicas) * self.num_instances\n        else:\n            self.num_samples = math.ceil(self.length / self.num_replicas) \n        self.total_size = self.num_samples * self.num_replicas\n\n    def __iter__(self):\n        # deterministically shuffle based on epoch and seed\n        random.seed(self.seed + self.epoch)\n        np.random.seed(self.seed + self.epoch)\n\n        list_container = []\n        for pid in self.pids:\n            idxs = copy.deepcopy(self.index_dic[pid])\n            if len(idxs) < self.num_instances:\n                idxs = np.random.choice(idxs, size=self.num_instances, replace=True)\n            random.shuffle(idxs)\n            batch_idxs = []\n            for idx in idxs:\n                batch_idxs.append(idx)\n                if len(batch_idxs) == self.num_instances:\n                    list_container.append(batch_idxs)\n                    batch_idxs = []\n        random.shuffle(list_container)\n\n        # remove tail of data to make it evenly divisible.\n        list_container = list_container[:self.total_size//self.num_instances]\n        assert len(list_container) == self.total_size//self.num_instances\n\n        # subsample\n        list_container = list_container[self.rank:self.total_size//self.num_instances:self.num_replicas]\n        assert len(list_container) == self.num_samples//self.num_instances\n\n  
      ret = []\n        for batch_idxs in list_container:\n            ret.extend(batch_idxs)\n\n        return iter(ret)\n\n    def __len__(self):\n        return self.num_samples\n\n    def set_epoch(self, epoch):\n        \"\"\"\n        Sets the epoch for this sampler. This ensures all replicas\n        use a different random ordering for each epoch. Otherwise, the next iteration of this\n        sampler will yield the same ordering.\n\n        Args:\n            epoch (int): Epoch number.\n        \"\"\"\n        self.epoch = epoch\n\n\nclass DistributedInferenceSampler(Sampler):\n    \"\"\"\n    refer to: https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py\n\n    Distributed Sampler that subsamples indicies sequentially,\n    making it easier to collate all results at the end.\n    Even though we only use this sampler for eval and predict (no training),\n    which means that the model params won't have to be synced (i.e. 
will not hang\n    for synchronization even if varied number of forward passes), we still add extra\n    samples to the sampler to make it evenly divisible (like in `DistributedSampler`)\n    to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.\n    \"\"\"\n    def __init__(self, dataset, rank=None, num_replicas=None):\n        if num_replicas is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            num_replicas = dist.get_world_size()\n        if rank is None:\n            if not dist.is_available():\n                raise RuntimeError(\"Requires distributed package to be available\")\n            rank = dist.get_rank()\n        self.dataset = dataset\n        self.num_replicas = num_replicas\n        self.rank = rank\n\n        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))\n        self.total_size = self.num_samples * self.num_replicas\n\n    def __iter__(self):\n        indices = list(range(len(self.dataset)))\n        # add extra samples to make it evenly divisible\n        indices += [indices[-1]] * (self.total_size - len(indices))\n        # subsample\n        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]\n        return iter(indices)\n\n    def __len__(self):\n        return self.num_samples"
  },
  {
    "path": "data/spatial_transforms.py",
    "content": "import random\nimport math\nimport numbers\nimport collections\nimport numpy as np\nimport torch\nimport torchvision.transforms as T\nfrom PIL import Image, ImageOps\ntry:\n    import accimage\nexcept ImportError:\n    accimage = None\n\n\nclass Compose(object):\n    \"\"\"Composes several transforms together.\n\n    Args:\n        transforms (list of ``Transform`` objects): list of transforms to compose.\n\n    Example:\n        >>> transforms.Compose([\n        >>>     transforms.CenterCrop(10),\n        >>>     transforms.ToTensor(),\n        >>> ])\n    \"\"\"\n\n    def __init__(self, transforms):\n        self.transforms = transforms\n\n    def __call__(self, img):\n        for t in self.transforms:\n            img = t(img)\n        return img\n\n    def randomize_parameters(self):\n        for t in self.transforms:\n            t.randomize_parameters()\n\n\nclass ToTensor(object):\n    \"\"\"Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor.\n    Converts a PIL.Image or numpy.ndarray (H x W x C) in the range\n    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].\n    \"\"\"\n\n    def __init__(self, norm_value=255):\n        self.norm_value = norm_value\n\n    def __call__(self, pic):\n        \"\"\"\n        Args:\n            pic (PIL.Image or numpy.ndarray): Image to be converted to tensor.\n        Returns:\n            Tensor: Converted image.\n        \"\"\"\n        if isinstance(pic, np.ndarray):\n            # handle numpy array\n            img = torch.from_numpy(pic.transpose((2, 0, 1)))\n            # backward compatibility\n            return img.float().div(self.norm_value)\n\n        if accimage is not None and isinstance(pic, accimage.Image):\n            nppic = np.zeros(\n                [pic.channels, pic.height, pic.width], dtype=np.float32)\n            pic.copyto(nppic)\n            return torch.from_numpy(nppic)\n\n        # handle PIL Image\n        if pic.mode == 'I':\n            img 
= torch.from_numpy(np.array(pic, np.int32, copy=False))\n        elif pic.mode == 'I;16':\n            img = torch.from_numpy(np.array(pic, np.int16, copy=False))\n        else:\n            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))\n        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK\n        if pic.mode == 'YCbCr':\n            nchannel = 3\n        elif pic.mode == 'I;16':\n            nchannel = 1\n        else:\n            nchannel = len(pic.mode)\n        img = img.view(pic.size[1], pic.size[0], nchannel)\n        # put it from HWC to CHW format\n        # yikes, this transpose takes 80% of the loading time/CPU\n        img = img.transpose(0, 1).transpose(0, 2).contiguous()\n        if isinstance(img, torch.ByteTensor):\n            return img.float().div(self.norm_value)\n        else:\n            return img\n\n    def randomize_parameters(self):\n        pass\n\n\nclass Normalize(object):\n    \"\"\"Normalize an tensor image with mean and standard deviation.\n    Given mean: (R, G, B) and std: (R, G, B),\n    will normalize each channel of the torch.*Tensor, i.e.\n    channel = (channel - mean) / std\n\n    Args:\n        mean (sequence): Sequence of means for R, G, B channels respecitvely.\n        std (sequence): Sequence of standard deviations for R, G, B channels\n            respecitvely.\n    \"\"\"\n\n    def __init__(self, mean, std):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, tensor):\n        \"\"\"\n        Args:\n            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.\n        Returns:\n            Tensor: Normalized image.\n        \"\"\"\n        # TODO: make efficient\n        for t, m, s in zip(tensor, self.mean, self.std):\n            t.sub_(m).div_(s)\n        return tensor\n\n    def randomize_parameters(self):\n        pass\n\n\nclass Scale(object):\n    \"\"\"Rescale the input PIL.Image to the given size.\n\n    Args:\n        size (sequence 
or int): Desired output size. If size is a sequence like\n            (w, h), output size will be matched to this. If size is an int,\n            smaller edge of the image will be matched to this number.\n            i.e, if height > width, then image will be rescaled to\n            (size * height / width, size)\n        interpolation (int, optional): Desired interpolation. Default is\n            ``PIL.Image.BILINEAR``\n    \"\"\"\n\n    def __init__(self, size, interpolation=Image.BILINEAR):\n        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)\n        self.size = size\n        self.interpolation = interpolation\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL.Image): Image to be scaled.\n        Returns:\n            PIL.Image: Rescaled image.\n        \"\"\"\n        if isinstance(self.size, int):\n            w, h = img.size\n            if (w <= h and w == self.size) or (h <= w and h == self.size):\n                return img\n            if w < h:\n                ow = self.size\n                oh = int(self.size * h / w)\n                return img.resize((ow, oh), self.interpolation)\n            else:\n                oh = self.size\n                ow = int(self.size * w / h)\n                return img.resize((ow, oh), self.interpolation)\n        else:\n            return img.resize(self.size[::-1], self.interpolation)\n\n    def randomize_parameters(self):\n        pass\n\n\nclass RandomHorizontalFlip(object):\n    \"\"\"Horizontally flip the given PIL.Image randomly with a probability of 0.5.\"\"\"\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL.Image): Image to be flipped.\n        Returns:\n            PIL.Image: Randomly flipped image.\n        \"\"\"\n        if self.p < 0.5:\n            return img.transpose(Image.FLIP_LEFT_RIGHT)\n        return img\n\n    def randomize_parameters(self):\n        self.p = 
random.random()\n\n\nclass RandomCrop(object):\n    \"\"\"\n    With a probability, first increase image size to (1 + 1/8), and then perform random crop.\n\n    Args:\n        height (int): target height.\n        width (int): target width.\n        p (float): probability of performing this transformation. Default: 0.5.\n    \"\"\"\n    def __init__(self, size, p=0.5, interpolation=Image.BILINEAR):\n        if isinstance(size, numbers.Number):\n            self.size = (int(size), int(size))\n        else:\n            self.size = size\n\n        self.height, self.width = self.size\n        self.p = p\n        self.interpolation = interpolation\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be cropped.\n\n        Returns:\n            PIL Image: Cropped image.\n        \"\"\"\n        if not self.cropping:\n            return img.resize((self.width, self.height), self.interpolation)\n        \n        new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125))\n        resized_img = img.resize((new_width, new_height), self.interpolation)\n        x_maxrange = new_width - self.width\n        y_maxrange = new_height - self.height\n        x1 = int(round(self.tl_x * x_maxrange))\n        y1 = int(round(self.tl_y * y_maxrange))\n        return resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))\n\n    def randomize_parameters(self):\n        self.cropping = random.uniform(0, 1) < self.p\n        self.tl_x = random.random()\n        self.tl_y = random.random()\n\n\nclass RandomErasing(object):\n    \"\"\" \n    Randomly selects a rectangle region in an image and erases its pixels.\n\n    Reference:\n        Zhong et al. Random Erasing Data Augmentation. 
arxiv: 1708.04896, 2017.\n        \n    Args:\n         probability: The probability that the Random Erasing operation will be performed.\n         sl: Minimum proportion of erased area against input image.\n         sh: Maximum proportion of erased area against input image.\n         r1: Minimum aspect ratio of erased area.\n         mean: Erasing value. \n    \"\"\"\n    \n    def __init__(self, height=256, width=128, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.485, 0.456, 0.406]):\n        self.probability = probability\n        self.mean = mean\n        self.sl = sl\n        self.sh = sh\n        self.r1 = r1\n        self.height = height\n        self.width = width\n       \n    def __call__(self, img):\n        if self.re:\n            return img\n\n        if img.size()[0] == 3:\n            img[0, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[0]\n            img[1, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[1]\n            img[2, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[2]\n        else:\n            img[0, self.x1:self.x1+self.h, self.y1:self.y1+self.w] = self.mean[0]\n        return img\n\n    def randomize_parameters(self):\n        self.re = random.uniform(0, 1) < self.probability\n        self.h, self.w, self.x1, self.y1 = 0, 0, 0, 0\n        whether_re = False\n        if self.re:\n            for attempt in range(100):\n                area = self.height*self.width\n\n                target_area = random.uniform(self.sl, self.sh) * area\n                aspect_ratio = random.uniform(self.r1, 1/self.r1)\n\n                self.h = int(round(math.sqrt(target_area * aspect_ratio)))\n                self.w = int(round(math.sqrt(target_area / aspect_ratio)))\n                if self.w < self.width and self.h < self.height:\n                    self.x1 = random.randint(0, self.height - self.h)\n                    self.y1 = random.randint(0, self.width - self.w)\n                    whether_re 
= True\n                    break\n\n        self.re = whether_re"
  },
  {
    "path": "data/temporal_transforms.py",
    "content": "import random\nimport numpy as np\n\n\nclass TemporalRandomCrop(object):\n    \"\"\"Temporally crop the given frame indices at a random location.\n\n    If the number of frames is less than the size,\n    loop the indices as many times as necessary to satisfy the size.\n\n    Args:\n        size (int): Desired output size of the crop.\n        stride (int): Temporal sampling stride\n    \"\"\"\n\n    def __init__(self, size=4, stride=8):\n        self.size = size\n        self.stride = stride\n\n    def __call__(self, frame_indices):\n        \"\"\"\n        Args:\n            frame_indices (list): frame indices to be cropped.\n        Returns:\n            list: Cropped frame indices.\n        \"\"\"\n        frame_indices = list(frame_indices)\n\n        if len(frame_indices) >= self.size * self.stride:\n            rand_end = len(frame_indices) - (self.size - 1) * self.stride - 1\n            begin_index = random.randint(0, rand_end)\n            end_index = begin_index + (self.size - 1) * self.stride + 1\n            out = frame_indices[begin_index:end_index:self.stride]\n        elif len(frame_indices) >= self.size:\n            clips = []\n            for i in range(self.size):\n                    clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)])\n            out = []\n            for i in range(self.size):\n                out.append(random.choice(clips[i]))\n        else:\n            index = np.random.choice(len(frame_indices), size=self.size, replace=True)\n            index.sort()\n            out = [frame_indices[index[i]] for i in range(self.size)]\n\n        return out\n\n\nclass TemporalBeginCrop(object):\n    \"\"\"Temporally crop the given frame indices at a beginning.\n\n    If the number of frames is less than the size,\n    loop the indices as many times as necessary to satisfy the size.\n\n    Args:\n        size (int): Desired output size of the crop.\n        stride (int): 
Temporal sampling stride\n    \"\"\"\n\n    def __init__(self, size=8, stride=4):\n        self.size = size\n        self.stride = stride\n        \n    def __call__(self, frame_indices):\n        frame_indices = list(frame_indices)\n\n        if len(frame_indices) >= self.size * self.stride:\n            out = frame_indices[0 : self.size * self.stride : self.stride]\n        else:\n            out = frame_indices[0 : self.size]\n            while len(out) < self.size:\n                for index in out:\n                    if len(out) >= self.size:\n                        break\n                    out.append(index)\n\n        return out\n\n\nclass TemporalDivisionCrop(object):\n    \"\"\"Temporally crop the given frame indices by TSN.\n\n    Args:\n        size (int): Desired output size of the crop.\n    \"\"\"\n    def __init__(self, size=4):\n        self.size = size\n\n    def __call__(self, frame_indices):\n        \"\"\"\n        Args:\n            frame_indices (list): frame indices to be cropped.\n        Returns:\n            list: Cropped frame indices.\n        \"\"\"\n        frame_indices = list(frame_indices)\n\n        if len(frame_indices) >= self.size:\n            clips = []\n            for i in range(self.size):\n                clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)])\n            out = []\n            for i in range(self.size):\n                out.append(random.choice(clips[i]))\n        else:\n            index = np.random.choice(len(frame_indices), size=self.size, replace=True)\n            index.sort()\n            out = [frame_indices[index[i]] for i in range(self.size)]\n\n        return out\n"
  },
  {
    "path": "losses/__init__.py",
    "content": "from torch import nn\nfrom losses.cross_entropy_loss_with_label_smooth import CrossEntropyWithLabelSmooth\nfrom losses.triplet_loss import TripletLoss\nfrom losses.contrastive_loss import ContrastiveLoss\nfrom losses.arcface_loss import ArcFaceLoss\nfrom losses.cosface_loss import CosFaceLoss, PairwiseCosFaceLoss\nfrom losses.circle_loss import CircleLoss, PairwiseCircleLoss\nfrom losses.clothes_based_adversarial_loss import ClothesBasedAdversarialLoss, ClothesBasedAdversarialLossWithMemoryBank\n\n\ndef build_losses(config, num_train_clothes):\n    # Build identity classification loss\n    if config.LOSS.CLA_LOSS == 'crossentropy':\n        criterion_cla = nn.CrossEntropyLoss()\n    elif config.LOSS.CLA_LOSS == 'crossentropylabelsmooth':\n        criterion_cla = CrossEntropyWithLabelSmooth()\n    elif config.LOSS.CLA_LOSS == 'arcface':\n        criterion_cla = ArcFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)\n    elif config.LOSS.CLA_LOSS == 'cosface':\n        criterion_cla = CosFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)\n    elif config.LOSS.CLA_LOSS == 'circle':\n        criterion_cla = CircleLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M)\n    else:\n        raise KeyError(\"Invalid classification loss: '{}'\".format(config.LOSS.CLA_LOSS))\n\n    # Build pairwise loss\n    if config.LOSS.PAIR_LOSS == 'triplet':\n        criterion_pair = TripletLoss(margin=config.LOSS.PAIR_M)\n    elif config.LOSS.PAIR_LOSS == 'contrastive':\n        criterion_pair = ContrastiveLoss(scale=config.LOSS.PAIR_S)\n    elif config.LOSS.PAIR_LOSS == 'cosface':\n        criterion_pair = PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)\n    elif config.LOSS.PAIR_LOSS == 'circle':\n        criterion_pair = PairwiseCircleLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M)\n    else:\n        raise KeyError(\"Invalid pairwise loss: '{}'\".format(config.LOSS.PAIR_LOSS))\n\n    # Build clothes classification 
loss\n    if config.LOSS.CLOTHES_CLA_LOSS == 'crossentropy':\n        criterion_clothes = nn.CrossEntropyLoss()\n    elif config.LOSS.CLOTHES_CLA_LOSS == 'cosface':\n        criterion_clothes = CosFaceLoss(scale=config.LOSS.CLA_S, margin=0)\n    else:\n        raise KeyError(\"Invalid clothes classification loss: '{}'\".format(config.LOSS.CLOTHES_CLA_LOSS))\n\n    # Build clothes-based adversarial loss\n    if config.LOSS.CAL == 'cal':\n        criterion_cal = ClothesBasedAdversarialLoss(scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON)\n    elif config.LOSS.CAL == 'calwithmemory':\n        criterion_cal = ClothesBasedAdversarialLossWithMemoryBank(num_clothes=num_train_clothes, feat_dim=config.MODEL.FEATURE_DIM,\n                             momentum=config.LOSS.MOMENTUM, scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON)\n    else:\n        raise KeyError(\"Invalid clothing classification loss: '{}'\".format(config.LOSS.CAL))\n\n    return criterion_cla, criterion_pair, criterion_clothes, criterion_cal\n"
  },
  {
    "path": "losses/arcface_loss.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass ArcFaceLoss(nn.Module):\n    \"\"\" ArcFace loss.\n\n    Reference:\n        Deng et al. ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019.\n\n    Args:\n        scale (float): scaling factor.\n        margin (float): pre-defined margin.\n    \"\"\"\n    def __init__(self, scale=16, margin=0.1):\n        super().__init__()\n        self.s = scale\n        self.m = margin\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        # get a one-hot index\n        index = inputs.data * 0.0 \n        index.scatter_(1, targets.data.view(-1, 1), 1)\n        index = index.bool()\n\n        cos_m = math.cos(self.m)\n        sin_m = math.sin(self.m)\n        cos_t = inputs[index]\n        sin_t = torch.sqrt(1.0 - cos_t * cos_t)\n        cos_t_add_m = cos_t * cos_m  - sin_t * sin_m\n\n        cond_v = cos_t - math.cos(math.pi - self.m)\n        cond = F.relu(cond_v)\n        keep = cos_t - math.sin(math.pi - self.m) * self.m\n\n        cos_t_add_m = torch.where(cond.bool(), cos_t_add_m, keep)\n\n        output = inputs * 1.0 \n        output[index] = cos_t_add_m\n        output = self.s * output\n\n        return F.cross_entropy(output, targets)\n"
  },
  {
    "path": "losses/circle_loss.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import distributed as dist\nfrom losses.gather import GatherLayer\n\n\nclass CircleLoss(nn.Module):\n    \"\"\" Circle Loss based on the predictions of classifier.\n\n    Reference:\n        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020.\n\n    Args:\n        scale (float): scaling factor.\n        margin (float): pre-defined margin.\n    \"\"\"\n    def __init__(self, scale=96, margin=0.3, **kwargs):\n        super().__init__()\n        self.s = scale\n        self.m = margin\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        mask = torch.zeros_like(inputs).cuda()\n        mask.scatter_(1, targets.view(-1, 1), 1.0)\n    \n        pos_scale = self.s * F.relu(1 + self.m - inputs.detach())\n        neg_scale = self.s * F.relu(inputs.detach() + self.m)\n        scale_matrix = pos_scale * mask + neg_scale * (1 - mask)\n\n        scores = (inputs - (1 - self.m) * mask - self.m * (1 - mask)) * scale_matrix\n        \n        loss = F.cross_entropy(scores, targets)\n\n        return loss\n\n\nclass PairwiseCircleLoss(nn.Module):\n    \"\"\" Circle Loss among sample pairs.\n\n    Reference:\n        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. 
In CVPR, 2020.\n\n    Args:\n        scale (float): scaling factor.\n        margin (float): pre-defined margin.\n    \"\"\"\n    def __init__(self, scale=48, margin=0.35, **kwargs):\n        super().__init__()\n        self.s = scale\n        self.m = margin\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: sample features (before classifier) with shape (batch_size, feat_dim)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        # l2-normalize\n        inputs = F.normalize(inputs, p=2, dim=1)\n\n        # gather all samples from different GPUs as gallery to compute pairwise loss.\n        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)\n        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)\n        m, n = targets.size(0), gallery_targets.size(0)\n\n        # compute cosine similarity\n        similarities = torch.matmul(inputs, gallery_inputs.t())\n        \n        # get mask for pos/neg pairs\n        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)\n        mask = torch.eq(targets, gallery_targets.T).float().cuda()\n        mask_self = torch.zeros_like(mask)\n        rank = dist.get_rank()\n        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda()\n        mask_pos = mask - mask_self\n        mask_neg = 1 - mask\n\n        pos_scale = self.s * F.relu(1 + self.m - similarities.detach())\n        neg_scale = self.s * F.relu(similarities.detach() + self.m)\n        scale_matrix = pos_scale * mask_pos + neg_scale * mask_neg\n\n        scores = (similarities - self.m) * mask_neg + (1 - self.m - similarities) * mask_pos\n        scores = scores * scale_matrix\n        \n        neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1)\n        pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1)\n\n        loss = F.softplus(neg_scores_LSE + 
pos_scores_LSE).mean()\n\n        return loss\n"
  },
  {
    "path": "losses/clothes_based_adversarial_loss.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom losses.gather import GatherLayer\n\n\nclass ClothesBasedAdversarialLoss(nn.Module):\n    \"\"\" Clothes-based Adversarial Loss.\n\n    Reference:\n        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022.\n\n    Args:\n        scale (float): scaling factor.\n        epsilon (float): a trade-off hyper-parameter.\n    \"\"\"\n    def __init__(self, scale=16, epsilon=0.1):\n        super().__init__()\n        self.scale = scale\n        self.epsilon = epsilon\n\n    def forward(self, inputs, targets, positive_mask):\n        \"\"\"\n        Args:\n            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)\n            targets: ground truth labels with shape (batch_size)\n            positive_mask: positive mask matrix with shape (batch_size, num_classes). The clothes classes with \n                the same identity as the anchor sample are defined as positive clothes classes and their mask \n                values are 1. 
The clothes classes with different identities from the anchor sample are defined \n                as negative clothes classes and their mask values in positive_mask are 0.\n        \"\"\"\n        inputs = self.scale * inputs\n        negtive_mask = 1 - positive_mask\n        identity_mask = torch.zeros(inputs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda()\n\n        exp_logits = torch.exp(inputs)\n        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negtive_mask).sum(1, keepdim=True) + exp_logits)\n        log_prob = inputs - log_sum_exp_pos_and_all_neg\n\n        mask = (1 - self.epsilon) * identity_mask + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask\n        loss = (- mask * log_prob).sum(1).mean()\n\n        return loss\n\n\nclass ClothesBasedAdversarialLossWithMemoryBank(nn.Module):\n    \"\"\" Clothes-based Adversarial Loss between mini batch and the samples in memory.\n\n    Reference:\n        Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. 
In CVPR, 2022.\n\n    Args:\n        num_clothes (int): the number of clothes classes.\n        feat_dim (int): the dimensions of feature.\n        momentum (float): momentum to update memory.\n        scale (float): scaling factor.\n        epsilon (float): a trade-off hyper-parameter.\n    \"\"\"\n    def __init__(self, num_clothes, feat_dim, momentum=0., scale=16, epsilon=0.1):\n        super().__init__()\n        self.num_clothes = num_clothes\n        self.feat_dim = feat_dim\n        self.momentum = momentum\n        self.epsilon = epsilon\n        self.scale = scale\n\n        self.register_buffer('feature_memory', torch.zeros((num_clothes, feat_dim)))\n        self.register_buffer('label_memory', torch.zeros(num_clothes, dtype=torch.int64) - 1)\n        self.has_been_filled = False\n\n    def forward(self, inputs, targets, positive_mask):\n        \"\"\"\n        Args:\n            inputs: sample features (before classifier) with shape (batch_size, feat_dim)\n            targets: ground truth labels with shape (batch_size)\n            positive_mask: positive mask matrix with shape (batch_size, num_classes). 
\n        \"\"\"\n        # gather all samples from different GPUs to update memory.\n        gathered_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)\n        gathered_targets = torch.cat(GatherLayer.apply(targets), dim=0)\n        self._update_memory(gathered_inputs.detach(), gathered_targets)\n\n        inputs_norm = F.normalize(inputs, p=2, dim=1)\n        memory_norm = F.normalize(self.feature_memory.detach(), p=2, dim=1)\n        similarities = torch.matmul(inputs_norm, memory_norm.t()) * self.scale\n\n        negtive_mask = 1 - positive_mask\n        mask_identity = torch.zeros(positive_mask.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda()\n\n        if not self.has_been_filled:\n            invalid_index = self.label_memory == -1\n            positive_mask[:, invalid_index] = 0\n            negtive_mask[:, invalid_index] = 0\n            if sum(invalid_index.type(torch.int)) == 0:\n                self.has_been_filled = True\n                print('Memory bank is full')\n\n        # compute log_prob\n        exp_logits = torch.exp(similarities)\n        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negtive_mask).sum(1, keepdim=True) + exp_logits)\n        log_prob = similarities - log_sum_exp_pos_and_all_neg\n\n        # compute mean of log-likelihood over positive\n        mask = (1 - self.epsilon) * mask_identity + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask\n        loss = (- mask * log_prob).sum(1).mean()\n        \n        return loss\n\n    def _update_memory(self, features, labels):\n        label_to_feat = {}\n        for x, y in zip(features, labels):\n            if y not in label_to_feat:\n                label_to_feat[y] = [x.unsqueeze(0)]\n            else:\n                label_to_feat[y].append(x.unsqueeze(0))\n        if not self.has_been_filled:\n            for y in label_to_feat:\n                feat = torch.mean(torch.cat(label_to_feat[y], dim=0), dim=0)\n                
self.feature_memory[y] = feat\n                self.label_memory[y] = y\n        else:\n            for y in label_to_feat:\n                feat = torch.mean(torch.cat(label_to_feat[y], dim=0), dim=0)\n                self.feature_memory[y] = self.momentum * self.feature_memory[y] + (1. - self.momentum) * feat\n                # self.embedding_memory[y] /= self.embedding_memory[y].norm()"
  },
  {
    "path": "losses/contrastive_loss.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import distributed as dist\nfrom losses.gather import GatherLayer\n\n\nclass ContrastiveLoss(nn.Module):\n    \"\"\" Supervised Contrastive Learning Loss among sample pairs.\n\n    Args:\n        scale (float): scaling factor.\n    \"\"\"\n    def __init__(self, scale=16, **kwargs):\n        super().__init__()\n        self.s = scale\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: sample features (before classifier) with shape (batch_size, feat_dim)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        # l2-normalize\n        inputs = F.normalize(inputs, p=2, dim=1)\n\n        # gather all samples from different GPUs as gallery to compute pairwise loss.\n        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)\n        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)\n        m, n = targets.size(0), gallery_targets.size(0)\n\n        # compute cosine similarity\n        similarities = torch.matmul(inputs, gallery_inputs.t()) * self.s\n        \n        # get mask for pos/neg pairs\n        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)\n        mask = torch.eq(targets, gallery_targets.T).float().cuda()\n        mask_self = torch.zeros_like(mask)\n        rank = dist.get_rank()\n        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda()\n        mask_pos = mask - mask_self\n        mask_neg = 1 - mask\n\n        # compute log_prob\n        exp_logits = torch.exp(similarities) * (1 - mask_self)\n        # log_prob = similarities - torch.log(exp_logits.sum(1, keepdim=True))\n        log_sum_exp_pos_and_all_neg = torch.log((exp_logits * mask_neg).sum(1, keepdim=True) + exp_logits)\n        log_prob = similarities - log_sum_exp_pos_and_all_neg\n\n        # compute mean of log-likelihood over positive\n        loss = 
(mask_pos * log_prob).sum(1) / mask_pos.sum(1)\n\n        loss = - loss.mean()\n\n        return loss"
  },
  {
    "path": "losses/cosface_loss.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch import distributed as dist\nfrom losses.gather import GatherLayer\n\n\nclass CosFaceLoss(nn.Module):\n    \"\"\" CosFace Loss based on the predictions of classifier.\n\n    Reference:\n        Wang et al. CosFace: Large Margin Cosine Loss for Deep Face Recognition. In CVPR, 2018.\n\n    Args:\n        scale (float): scaling factor.\n        margin (float): pre-defined margin.\n    \"\"\"\n    def __init__(self, scale=16, margin=0.1, **kwargs):\n        super().__init__()\n        self.s = scale\n        self.m = margin\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        one_hot = torch.zeros_like(inputs)\n        one_hot.scatter_(1, targets.view(-1, 1), 1.0)\n\n        output = self.s * (inputs - one_hot * self.m)\n\n        return F.cross_entropy(output, targets)\n\n\nclass PairwiseCosFaceLoss(nn.Module):\n    \"\"\" CosFace Loss among sample pairs.\n\n    Reference:\n        Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. 
In CVPR, 2020.\n\n    Args:\n        scale (float): scaling factor.\n        margin (float): pre-defined margin.\n    \"\"\"\n    def __init__(self, scale=16, margin=0):\n        super().__init__()\n        self.s = scale\n        self.m = margin\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: sample features (before classifier) with shape (batch_size, feat_dim)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        # l2-normalize\n        inputs = F.normalize(inputs, p=2, dim=1)\n\n        # gather all samples from different GPUs as gallery to compute pairwise loss.\n        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)\n        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)\n        m, n = targets.size(0), gallery_targets.size(0)\n\n        # compute cosine similarity\n        similarities = torch.matmul(inputs, gallery_inputs.t())\n        \n        # get mask for pos/neg pairs\n        targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1)\n        mask = torch.eq(targets, gallery_targets.T).float().cuda()\n        mask_self = torch.zeros_like(mask)\n        rank = dist.get_rank()\n        mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda()\n        mask_pos = mask - mask_self\n        mask_neg = 1 - mask\n\n        scores = (similarities + self.m) * mask_neg - similarities * mask_pos\n        scores = scores * self.s\n        \n        neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1)\n        pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1)\n\n        loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean()\n\n        return loss"
  },
  {
    "path": "losses/cross_entropy_loss_with_label_smooth.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass CrossEntropyWithLabelSmooth(nn.Module):\n    \"\"\" Cross entropy loss with label smoothing regularization.\n\n    Reference:\n        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. In CVPR, 2016.\n    Equation: \n        y = (1 - epsilon) * y + epsilon / K.\n\n    Args:\n        epsilon (float): a hyper-parameter in the above equation.\n    \"\"\"\n    def __init__(self, epsilon=0.1):\n        super().__init__()\n        self.epsilon = epsilon\n        self.logsoftmax = nn.LogSoftmax(dim=1)\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        _, num_classes = inputs.size()\n        log_probs = self.logsoftmax(inputs)\n        targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda()\n        targets = (1 - self.epsilon) * targets + self.epsilon / num_classes\n        loss = (- targets * log_probs).mean(0).sum()\n\n        return loss\n"
  },
  {
    "path": "losses/gather.py",
    "content": "import torch\nimport torch.distributed as dist\n\n\nclass GatherLayer(torch.autograd.Function):\n    \"\"\"Gather tensors from all processes, supporting backward propagation.\"\"\"\n\n    @staticmethod\n    def forward(ctx, input):\n        ctx.save_for_backward(input)\n        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]\n        dist.all_gather(output, input)\n\n        return tuple(output)\n\n    @staticmethod\n    def backward(ctx, *grads):\n        (input,) = ctx.saved_tensors\n        grad_out = torch.zeros_like(input)\n\n        # dist.reduce_scatter(grad_out, list(grads))\n        # grad_out.div_(dist.get_world_size())\n\n        grad_out[:] = grads[dist.get_rank()]\n\n        return grad_out"
  },
  {
    "path": "losses/triplet_loss.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom losses.gather import GatherLayer\n\n\nclass TripletLoss(nn.Module):\n    \"\"\" Triplet loss with hard example mining.\n\n    Reference:\n        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.\n\n    Args:\n        margin (float): pre-defined margin.\n\n    Note that we use cosine similarity, rather than Euclidean distance in the original paper.\n    \"\"\"\n    def __init__(self, margin=0.3):\n        super().__init__()\n        self.m = margin\n        self.ranking_loss = nn.MarginRankingLoss(margin=margin)\n\n    def forward(self, inputs, targets):\n        \"\"\"\n        Args:\n            inputs: sample features (before classifier) with shape (batch_size, feat_dim)\n            targets: ground truth labels with shape (batch_size)\n        \"\"\"\n        # l2-normalize\n        inputs = F.normalize(inputs, p=2, dim=1)\n\n        # gather all samples from different GPUs as gallery to compute pairwise loss.\n        gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0)\n        gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0)\n\n        # compute distance\n        dist = 1 - torch.matmul(inputs, gallery_inputs.t()) # values in [0, 2]\n\n        # get positive and negative masks\n        targets, gallery_targets = targets.view(-1,1), gallery_targets.view(-1,1)\n        mask_pos = torch.eq(targets, gallery_targets.T).float().cuda()\n        mask_neg = 1 - mask_pos\n\n        # For each anchor, find the hardest positive and negative pairs\n        dist_ap, _ = torch.max((dist - mask_neg * 99999999.), dim=1)\n        dist_an, _ = torch.min((dist + mask_pos * 99999999.), dim=1)\n\n        # Compute ranking hinge loss\n        y = torch.ones_like(dist_an)\n        loss = self.ranking_loss(dist_an, dist_ap, y)\n\n        return loss"
  },
  {
    "path": "main.py",
    "content": "import os\nimport sys\nimport time\nimport datetime\nimport argparse\nimport logging\nimport os.path as osp\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.optim import lr_scheduler\nfrom torch import distributed as dist\nfrom apex import amp\n\nfrom configs.default_img import get_img_config\nfrom configs.default_vid import get_vid_config\nfrom data import build_dataloader\nfrom models import build_model\nfrom losses import build_losses\nfrom tools.utils import save_checkpoint, set_seed, get_logger\nfrom train import train_cal, train_cal_with_memory\nfrom test import test, test_prcc\n\n\nVID_DATASET = ['ccvid']\n\n\ndef parse_option():\n    parser = argparse.ArgumentParser(description='Train clothes-changing re-id model with clothes-based adversarial loss')\n    parser.add_argument('--cfg', type=str, required=True, metavar=\"FILE\", help='path to config file')\n    # Datasets\n    parser.add_argument('--root', type=str, help=\"your root path to data directory\")\n    parser.add_argument('--dataset', type=str, default='ltcc', help=\"ltcc, prcc, vcclothes, ccvid, last, deepchange\")\n    # Miscs\n    parser.add_argument('--output', type=str, help=\"your output path to save model and logs\")\n    parser.add_argument('--resume', type=str, metavar='PATH')\n    parser.add_argument('--amp', action='store_true', help=\"automatic mixed precision\")\n    parser.add_argument('--eval', action='store_true', help=\"evaluation only\")\n    parser.add_argument('--tag', type=str, help='tag for log file')\n    parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')\n\n    args, unparsed = parser.parse_known_args()\n    if args.dataset in VID_DATASET:\n        config = get_vid_config(args)\n    else:\n        config = get_img_config(args)\n\n    return config\n\n\ndef main(config):\n    # Build dataloader\n    if config.DATA.DATASET == 'prcc':\n        trainloader, 
queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config)\n    else:\n        trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config)\n    # Define a matrix pid2clothes with shape (num_pids, num_clothes). \n    # pid2clothes[i, j] = 1 when j-th clothes belongs to i-th identity. Otherwise, pid2clothes[i, j] = 0.\n    pid2clothes = torch.from_numpy(dataset.pid2clothes)\n\n    # Build model\n    model, classifier, clothes_classifier = build_model(config, dataset.num_train_pids, dataset.num_train_clothes)\n    # Build identity classification loss, pairwise loss, clothes classificaiton loss, and adversarial loss.\n    criterion_cla, criterion_pair, criterion_clothes, criterion_adv = build_losses(config, dataset.num_train_clothes)\n    # Build optimizer\n    parameters = list(model.parameters()) + list(classifier.parameters())\n    if config.TRAIN.OPTIMIZER.NAME == 'adam':\n        optimizer = optim.Adam(parameters, lr=config.TRAIN.OPTIMIZER.LR, \n                               weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)\n        optimizer_cc = optim.Adam(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, \n                                  weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)\n    elif config.TRAIN.OPTIMIZER.NAME == 'adamw':\n        optimizer = optim.AdamW(parameters, lr=config.TRAIN.OPTIMIZER.LR, \n                               weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)\n        optimizer_cc = optim.AdamW(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, \n                                  weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY)\n    elif config.TRAIN.OPTIMIZER.NAME == 'sgd':\n        optimizer = optim.SGD(parameters, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, \n                              weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True)\n        optimizer_cc = optim.SGD(clothes_classifier.parameters(), 
lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, \n                              weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True)\n    else:\n        raise KeyError(\"Unknown optimizer: {}\".format(config.TRAIN.OPTIMIZER.NAME))\n    # Build lr_scheduler\n    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE, \n                                         gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE)\n\n    start_epoch = config.TRAIN.START_EPOCH\n    if config.MODEL.RESUME:\n        logger.info(\"Loading checkpoint from '{}'\".format(config.MODEL.RESUME))\n        checkpoint = torch.load(config.MODEL.RESUME)\n        model.load_state_dict(checkpoint['model_state_dict'])\n        classifier.load_state_dict(checkpoint['classifier_state_dict'])\n        if config.LOSS.CAL == 'calwithmemory':\n            criterion_adv.load_state_dict(checkpoint['clothes_classifier_state_dict'])\n        else:\n            clothes_classifier.load_state_dict(checkpoint['clothes_classifier_state_dict'])\n        start_epoch = checkpoint['epoch']\n\n    local_rank = dist.get_rank()\n    model = model.cuda(local_rank)\n    classifier = classifier.cuda(local_rank)\n    if config.LOSS.CAL == 'calwithmemory':\n        criterion_adv = criterion_adv.cuda(local_rank)\n    else:\n        clothes_classifier = clothes_classifier.cuda(local_rank)\n    torch.cuda.set_device(local_rank)\n\n    if config.TRAIN.AMP:\n        [model, classifier], optimizer = amp.initialize([model, classifier], optimizer, opt_level=\"O1\")\n        if config.LOSS.CAL != 'calwithmemory':\n            clothes_classifier, optimizer_cc = amp.initialize(clothes_classifier, optimizer_cc, opt_level=\"O1\")\n\n    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)\n    classifier = nn.parallel.DistributedDataParallel(classifier, device_ids=[local_rank], output_device=local_rank)\n    if config.LOSS.CAL != 'calwithmemory':\n        
clothes_classifier = nn.parallel.DistributedDataParallel(clothes_classifier, device_ids=[local_rank], output_device=local_rank)\n\n    if config.EVAL_MODE:\n        logger.info(\"Evaluate only\")\n        with torch.no_grad():\n            if config.DATA.DATASET == 'prcc':\n                test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)\n            else:\n                test(config, model, queryloader, galleryloader, dataset)\n        return\n\n    start_time = time.time()\n    train_time = 0\n    best_rank1 = -np.inf\n    best_epoch = 0\n    logger.info(\"==> Start training\")\n    for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH):\n        train_sampler.set_epoch(epoch)\n        start_train_time = time.time()\n        if config.LOSS.CAL == 'calwithmemory':\n            train_cal_with_memory(config, epoch, model, classifier, criterion_cla, criterion_pair, \n                criterion_adv, optimizer, trainloader, pid2clothes)\n        else:\n            train_cal(config, epoch, model, classifier, clothes_classifier, criterion_cla, criterion_pair, \n                criterion_clothes, criterion_adv, optimizer, optimizer_cc, trainloader, pid2clothes)\n        train_time += round(time.time() - start_train_time)        \n        \n        if (epoch+1) > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0 and \\\n            (epoch+1) % config.TEST.EVAL_STEP == 0 or (epoch+1) == config.TRAIN.MAX_EPOCH:\n            logger.info(\"==> Test\")\n            torch.cuda.empty_cache()\n            if config.DATA.DATASET == 'prcc':\n                rank1 = test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset)\n            else:\n                rank1 = test(config, model, queryloader, galleryloader, dataset)\n            torch.cuda.empty_cache()\n            is_best = rank1 > best_rank1\n            if is_best:\n                best_rank1 = rank1\n                best_epoch = epoch + 1\n\n            model_state_dict = 
model.module.state_dict()\n            classifier_state_dict = classifier.module.state_dict()\n            if config.LOSS.CAL == 'calwithmemory':\n                clothes_classifier_state_dict = criterion_adv.state_dict()\n            else:\n                clothes_classifier_state_dict = clothes_classifier.module.state_dict()\n            if local_rank == 0:\n                save_checkpoint({\n                    'model_state_dict': model_state_dict,\n                    'classifier_state_dict': classifier_state_dict,\n                    'clothes_classifier_state_dict': clothes_classifier_state_dict,\n                    'rank1': rank1,\n                    'epoch': epoch,\n                }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch+1) + '.pth.tar'))\n        scheduler.step()\n\n    logger.info(\"==> Best Rank-1 {:.1%}, achieved at epoch {}\".format(best_rank1, best_epoch))\n\n    elapsed = round(time.time() - start_time)\n    elapsed = str(datetime.timedelta(seconds=elapsed))\n    train_time = str(datetime.timedelta(seconds=train_time))\n    logger.info(\"Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.\".format(elapsed, train_time))\n    \n\nif __name__ == '__main__':\n    config = parse_option()\n    # Set GPU\n    os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU\n    # Init dist\n    dist.init_process_group(backend=\"nccl\", init_method='env://')\n    local_rank = dist.get_rank()\n    # Set random seed\n    set_seed(config.SEED + local_rank)\n    # get logger\n    if not config.EVAL_MODE:\n        output_file = osp.join(config.OUTPUT, 'log_train_.log')\n    else:\n        output_file = osp.join(config.OUTPUT, 'log_test.log')\n    logger = get_logger(output_file, local_rank, 'reid')\n    logger.info(\"Config:\\n-----------------------------------------\")\n    logger.info(config)\n    logger.info(\"-----------------------------------------\")\n\n    main(config)"
  },
  {
    "path": "models/__init__.py",
    "content": "import logging\nfrom models.classifier import Classifier, NormalizedClassifier\nfrom models.img_resnet import ResNet50\nfrom models.vid_resnet import C2DResNet50, I3DResNet50, AP3DResNet50, NLResNet50, AP3DNLResNet50\n\n\n__factory = {\n    'resnet50': ResNet50,\n    'c2dres50': C2DResNet50,\n    'i3dres50': I3DResNet50,\n    'ap3dres50': AP3DResNet50,\n    'nlres50': NLResNet50,\n    'ap3dnlres50': AP3DNLResNet50,\n}\n\n\ndef build_model(config, num_identities, num_clothes):\n    logger = logging.getLogger('reid.model')\n    # Build backbone\n    logger.info(\"Initializing model: {}\".format(config.MODEL.NAME))\n    if config.MODEL.NAME not in __factory.keys():\n        raise KeyError(\"Invalid model: '{}'\".format(config.MODEL.NAME))\n    else:\n        logger.info(\"Init model: '{}'\".format(config.MODEL.NAME))\n        model = __factory[config.MODEL.NAME](config)\n    logger.info(\"Model size: {:.5f}M\".format(sum(p.numel() for p in model.parameters())/1000000.0))\n\n    # Build classifier\n    if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']:\n        identity_classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities)\n    else:\n        identity_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities)\n\n    clothes_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes)\n\n    return model, identity_classifier, clothes_classifier"
  },
  {
    "path": "models/classifier.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom torch.nn import Parameter\n\n\n__all__ = ['Classifier', 'NormalizedClassifier']\n\n\nclass Classifier(nn.Module):\n    def __init__(self, feature_dim, num_classes):\n        super().__init__()\n        self.classifier = nn.Linear(feature_dim, num_classes)\n        init.normal_(self.classifier.weight.data, std=0.001)\n        init.constant_(self.classifier.bias.data, 0.0)\n\n    def forward(self, x):\n        y = self.classifier(x)\n\n        return y\n        \n\nclass NormalizedClassifier(nn.Module):\n    def __init__(self, feature_dim, num_classes):\n        super().__init__()\n        self.weight = Parameter(torch.Tensor(num_classes, feature_dim))\n        self.weight.data.uniform_(-1, 1).renorm_(2,0,1e-5).mul_(1e5) \n\n    def forward(self, x):\n        w = self.weight  \n\n        x = F.normalize(x, p=2, dim=1)\n        w = F.normalize(w, p=2, dim=1)\n\n        return F.linear(x, w)\n\n\n\n"
  },
  {
    "path": "models/img_resnet.py",
    "content": "import torchvision\nfrom torch import nn\nfrom torch.nn import init\nfrom models.utils import pooling\n        \n\nclass ResNet50(nn.Module):\n    def __init__(self, config, **kwargs):\n        super().__init__()\n\n        resnet50 = torchvision.models.resnet50(pretrained=True)\n        if config.MODEL.RES4_STRIDE == 1:\n            resnet50.layer4[0].conv2.stride=(1, 1)\n            resnet50.layer4[0].downsample[0].stride=(1, 1) \n        self.base = nn.Sequential(*list(resnet50.children())[:-2])\n\n        if config.MODEL.POOLING.NAME == 'avg':\n            self.globalpooling = nn.AdaptiveAvgPool2d(1)\n        elif config.MODEL.POOLING.NAME == 'max':\n            self.globalpooling = nn.AdaptiveMaxPool2d(1)\n        elif config.MODEL.POOLING.NAME == 'gem':\n            self.globalpooling = pooling.GeMPooling(p=config.MODEL.POOLING.P)\n        elif config.MODEL.POOLING.NAME == 'maxavg':\n            self.globalpooling = pooling.MaxAvgPooling()\n        else:\n            raise KeyError(\"Invalid pooling: '{}'\".format(config.MODEL.POOLING.NAME))\n\n        self.bn = nn.BatchNorm1d(config.MODEL.FEATURE_DIM)\n        init.normal_(self.bn.weight.data, 1.0, 0.02)\n        init.constant_(self.bn.bias.data, 0.0)\n        \n    def forward(self, x):\n        x = self.base(x)\n        x = self.globalpooling(x)\n        x = x.view(x.size(0), -1)\n        f = self.bn(x)\n\n        return f"
  },
  {
    "path": "models/utils/c3d_blocks.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass APM(nn.Module):\n    def __init__(self, in_channels, out_channels, time_dim=3, temperature=4, contrastive_att=True):\n        super(APM, self).__init__()\n\n        self.time_dim = time_dim \n        self.temperature = temperature\n        self.contrastive_att = contrastive_att\n\n        padding = (0, 0, 0, 0, (time_dim-1)//2, (time_dim-1)//2)\n        self.padding = nn.ConstantPad3d(padding, value=0)\n\n        self.semantic_mapping = nn.Conv3d(in_channels, out_channels, \\\n                                          kernel_size=1, bias=False)          \n        if self.contrastive_att:  \n            self.x_mapping = nn.Conv3d(in_channels, out_channels, \\\n                                       kernel_size=1, bias=False)\n            self.n_mapping = nn.Conv3d(in_channels, out_channels, \\\n                                       kernel_size=1, bias=False)\n            self.contrastive_att_net = nn.Sequential(nn.Conv3d(out_channels, 1, \\\n                                kernel_size=1, bias=False), nn.Sigmoid())\n\n    def forward(self, x):\n        b, c, t, h, w = x.size()\n        N = self.time_dim\n\n        neighbor_time_index = torch.cat([(torch.arange(0,t)+i).unsqueeze(0) for i in range(N) if i!=N//2], dim=0).t().flatten().long()\n\n        # feature map registration\n        semantic = self.semantic_mapping(x) # (b, c/16, t, h, w)\n        x_norm = F.normalize(semantic, p=2, dim=1) # (b, c/16, t, h, w)\n        x_norm_padding = self.padding(x_norm) # (b, c/16, t+2, h, w)\n        x_norm_expand = x_norm.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).permute(0, 2, 3, 4, 5, 1).contiguous().view(-1, h*w, c//16) # (b*t*2, h*w, c/16) \n        neighbor_norm = x_norm_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 1, 3, 4).contiguous().view(-1, c//16, h*w) # (b*t*2, c/16, h*w) \n\n        similarity = torch.matmul(x_norm_expand, neighbor_norm) * self.temperature # 
(b*t*2, h*w, h*w)\n        similarity = F.softmax(similarity, dim=-1) # (b*t*2, h*w, h*w)\n\n        x_padding = self.padding(x)\n        neighbor = x_padding[:, :, neighbor_time_index, :, :].permute(0, 2, 3, 4, 1).contiguous().view(-1, h*w, c)\n        neighbor_new = torch.matmul(similarity, neighbor).view(b, t*(N-1), h, w, c).permute(0, 4, 1, 2, 3) # (b, c, t*2, h, w)\n\n        # contrastive attention\n        if self.contrastive_att:\n            x_att = self.x_mapping(x.unsqueeze(3).expand(-1, -1, -1, N-1, -1, -1).contiguous().view(b, c, (N-1)*t, h, w).detach())\n            n_att = self.n_mapping(neighbor_new.detach())\n            contrastive_att = self.contrastive_att_net(x_att * n_att)    \n            neighbor_new = neighbor_new * contrastive_att\n\n        # integrating feature maps\n        x_offset = torch.zeros([b, c, N*t, h, w], dtype=x.data.dtype, device=x.device.type)\n        x_index = torch.tensor([i for i in range(t*N) if i%N==N//2])\n        neighbor_index = torch.tensor([i for i in range(t*N) if i%N!=N//2])\n        x_offset[:, :, x_index, :, :] += x\n        x_offset[:, :, neighbor_index, :, :] += neighbor_new\n\n        return x_offset\n\n\nclass C2D(nn.Module):\n    def __init__(self, conv2d, **kwargs):\n        super(C2D, self).__init__()\n\n        # conv3d kernel\n        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (1, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                kernel_size=kernel_dim, padding=padding, \\\n                                stride=stride, bias=conv2d.bias)\n\n        # init the parameters of conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2)\n        weight_3d[:, :, 0, :, :] = weight_2d\n        self.conv3d.weight = 
nn.Parameter(weight_3d)\n        self.conv3d.bias = conv2d.bias\n\n    def forward(self, x):\n        out = self.conv3d(x)\n\n        return out\n\n\nclass I3D(nn.Module):\n    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):\n        super(I3D, self).__init__()\n\n        # conv3d kernel\n        kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (time_stride, conv2d.stride[0], conv2d.stride[0])\n        padding = (time_dim//2, conv2d.padding[0], conv2d.padding[1])\n        self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                kernel_size=kernel_dim, padding=padding, \\\n                                stride=stride, bias=conv2d.bias)\n\n        # init the parameters of conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)\n        middle_idx = time_dim // 2\n        weight_3d[:, :, middle_idx, :, :] = weight_2d\n        self.conv3d.weight = nn.Parameter(weight_3d)\n        self.conv3d.bias = conv2d.bias\n\n    def forward(self, x):\n        out = self.conv3d(x)\n\n        return out\n\n\nclass API3D(nn.Module):\n    def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):\n        super(API3D, self).__init__()\n\n        self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \\\n                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)\n        \n        # conv3d kernel\n        kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                kernel_size=kernel_dim, padding=padding, \\\n                                
stride=stride, bias=conv2d.bias)\n\n        # init the parameters of conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)\n        middle_idx = time_dim // 2\n        weight_3d[:, :, middle_idx, :, :] = weight_2d\n        self.conv3d.weight = nn.Parameter(weight_3d)\n        self.conv3d.bias = conv2d.bias\n\n    def forward(self, x):\n        x_offset = self.APM(x)\n        out = self.conv3d(x_offset)\n\n        return out\n\n\nclass P3DA(nn.Module):\n    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):\n        super(P3DA, self).__init__()\n\n        # spatial conv3d kernel\n        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (1, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                        kernel_size=kernel_dim, padding=padding, \\\n                                        stride=stride, bias=conv2d.bias)\n\n        # init the parameters of spatial_conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2)\n        weight_3d[:, :, 0, :, :] = weight_2d\n        self.spatial_conv3d.weight = nn.Parameter(weight_3d)\n        self.spatial_conv3d.bias = conv2d.bias\n\n\n        # temporal conv3d kernel\n        kernel_dim = (time_dim, 1, 1)\n        stride = (time_stride, 1, 1)\n        padding = (time_dim//2, 0, 0)\n        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \\\n                                         kernel_size=kernel_dim, padding=padding, \\\n                                         stride=stride, bias=False)\n\n        # init the parameters of temporal_conv3d\n        weight_2d = 
torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2)\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)\n        middle_idx = time_dim // 2\n        weight_3d[:, :, middle_idx, :, :] = weight_2d\n        self.temporal_conv3d.weight = nn.Parameter(weight_3d)\n\n\n    def forward(self, x):\n        x = self.spatial_conv3d(x)\n        out = self.temporal_conv3d(x)\n\n        return out\n\n\nclass P3DB(nn.Module):\n    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):\n        super(P3DB, self).__init__()\n\n        # spatial conv3d kernel\n        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (1, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                        kernel_size=kernel_dim, padding=padding, \\\n                                        stride=stride, bias=conv2d.bias)\n\n        # init the parameters of spatial_conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2)\n        weight_3d[:, :, 0, :, :] = weight_2d\n        self.spatial_conv3d.weight = nn.Parameter(weight_3d)\n        self.spatial_conv3d.bias = conv2d.bias\n\n\n        # temporal conv3d kernel\n        kernel_dim = (time_dim, 1, 1)\n        stride = (time_stride, conv2d.stride[0], conv2d.stride[0])\n        padding = (time_dim//2, 0, 0)\n        self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                         kernel_size=kernel_dim, padding=padding, \\\n                                         stride=stride, bias=False)\n\n        # init the parameters of temporal_conv3d\n        nn.init.constant_(self.temporal_conv3d.weight, 0)\n\n\n    def forward(self, x):\n        # 
print(x.shape)\n        out1 = self.spatial_conv3d(x)\n        # print(out1.shape)\n        out2 = self.temporal_conv3d(x)\n        # print(out2.shape)\n        out = out1 + out2\n\n        return out\n\n\nclass P3DC(nn.Module):\n    def __init__(self, conv2d, time_dim=3, time_stride=1, **kwargs):\n        super(P3DC, self).__init__()\n\n        # spatial conv3d kernel\n        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (1, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                        kernel_size=kernel_dim, padding=padding, \\\n                                        stride=stride, bias=conv2d.bias)\n\n        # init the parameters of spatial_conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2)\n        weight_3d[:, :, 0, :, :] = weight_2d\n        self.spatial_conv3d.weight = nn.Parameter(weight_3d)\n        self.spatial_conv3d.bias = conv2d.bias\n\n\n        # temporal conv3d kernel\n        kernel_dim = (time_dim, 1, 1)\n        stride = (time_stride, 1, 1)\n        padding = (time_dim//2, 0, 0)\n        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \\\n                                         kernel_size=kernel_dim, padding=padding, \\\n                                         stride=stride, bias=False)\n\n        # init the parameters of temporal_conv3d\n        nn.init.constant_(self.temporal_conv3d.weight, 0)\n\n\n    def forward(self, x):\n        out = self.spatial_conv3d(x)\n        residual = self.temporal_conv3d(out)\n        out = out + residual\n\n        return out\n\n\nclass APP3DA(nn.Module):\n    def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):\n        super(APP3DA, self).__init__()\n\n       
 self.APM = APM(conv2d.out_channels, conv2d.out_channels//16, \\\n                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)\n\n        # spatial conv3d kernel\n        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (1, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                        kernel_size=kernel_dim, padding=padding, \\\n                                        stride=stride, bias=conv2d.bias)\n\n        # init the parameters of spatial_conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2)\n        weight_3d[:, :, 0, :, :] = weight_2d\n        self.spatial_conv3d.weight = nn.Parameter(weight_3d)\n        self.spatial_conv3d.bias = conv2d.bias\n\n\n        # temporal conv3d kernel\n        kernel_dim = (time_dim, 1, 1)\n        stride = (time_stride*time_dim, 1, 1)\n        padding = (0, 0, 0)\n        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \\\n                                         kernel_size=kernel_dim, padding=padding, \\\n                                         stride=stride, bias=False)\n\n        # init the parameters of temporal_conv3d\n        weight_2d = torch.eye(conv2d.out_channels).unsqueeze(2).unsqueeze(2)\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)\n        middle_idx = time_dim // 2\n        weight_3d[:, :, middle_idx, :, :] = weight_2d\n        self.temporal_conv3d.weight = nn.Parameter(weight_3d)\n\n\n    def forward(self, x):\n        x = self.spatial_conv3d(x)\n        out = self.temporal_conv3d(self.APM(x))\n\n        return out\n\n\nclass APP3DB(nn.Module):\n    def __init__(self, conv2d, time_dim=3, 
time_stride=1, temperature=4, contrastive_att=True):\n        super(APP3DB, self).__init__()\n\n        self.APM = APM(conv2d.in_channels, conv2d.in_channels//16, \\\n                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)\n\n        # spatial conv3d kernel\n        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (1, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                        kernel_size=kernel_dim, padding=padding, \\\n                                        stride=stride, bias=conv2d.bias)\n\n        # init the parameters of spatial_conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2)\n        weight_3d[:, :, 0, :, :] = weight_2d\n        self.spatial_conv3d.weight = nn.Parameter(weight_3d)\n        self.spatial_conv3d.bias = conv2d.bias\n\n\n        # temporal conv3d kernel\n        kernel_dim = (time_dim, 1, 1)\n        stride = (time_stride*time_dim, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, 0, 0)\n        self.temporal_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                         kernel_size=kernel_dim, padding=padding, \\\n                                         stride=stride, bias=False)\n\n        # init the parameters of temporal_conv3d\n        nn.init.constant_(self.temporal_conv3d.weight, 0)\n\n\n    def forward(self, x):\n        out1 = self.spatial_conv3d(x)\n        out2 = self.temporal_conv3d(self.APM(x))\n        out = out1 + out2\n\n        return out\n\n\nclass APP3DC(nn.Module):\n    def __init__(self, conv2d, time_dim=3, time_stride=1, temperature=4, contrastive_att=True):\n        super(APP3DC, self).__init__()\n\n        self.APM = APM(conv2d.out_channels, 
conv2d.out_channels//16, \\\n                       time_dim=time_dim, temperature=temperature, contrastive_att=contrastive_att)\n\n        # spatial conv3d kernel\n        kernel_dim = (1, conv2d.kernel_size[0], conv2d.kernel_size[1])\n        stride = (1, conv2d.stride[0], conv2d.stride[0])\n        padding = (0, conv2d.padding[0], conv2d.padding[1])\n        self.spatial_conv3d = nn.Conv3d(conv2d.in_channels, conv2d.out_channels, \\\n                                        kernel_size=kernel_dim, padding=padding, \\\n                                        stride=stride, bias=conv2d.bias)\n\n        # init the parameters of spatial_conv3d\n        weight_2d = conv2d.weight.data\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2)\n        weight_3d[:, :, 0, :, :] = weight_2d\n        self.spatial_conv3d.weight = nn.Parameter(weight_3d)\n        self.spatial_conv3d.bias = conv2d.bias\n\n\n        # temporal conv3d kernel\n        kernel_dim = (time_dim, 1, 1)\n        stride = (time_stride*time_dim, 1, 1)\n        padding = (0, 0, 0)\n        self.temporal_conv3d = nn.Conv3d(conv2d.out_channels, conv2d.out_channels, \\\n                                         kernel_size=kernel_dim, padding=padding, \\\n                                         stride=stride, bias=False)\n\n        # init the parameters of temporal_conv3d\n        nn.init.constant_(self.temporal_conv3d.weight, 0)\n\n\n    def forward(self, x):\n        out = self.spatial_conv3d(x)\n        residual = self.temporal_conv3d(self.APM(out))\n        out = out + residual\n\n        return out\n"
  },
  {
    "path": "models/utils/inflate.py",
    "content": "# inflate 2D modules to 3D modules\nimport torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\n\ndef inflate_conv(conv2d,\n                 time_dim=1,\n                 time_padding=0,\n                 time_stride=1,\n                 time_dilation=1,\n                 center=False):\n    # To preserve activations, padding should be by continuity and not zero\n    # or no padding in time dimension\n    kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1])\n    padding = (time_padding, conv2d.padding[0], conv2d.padding[1])\n    stride = (time_stride, conv2d.stride[0], conv2d.stride[0])\n    dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1])\n    conv3d = nn.Conv3d(\n        conv2d.in_channels,\n        conv2d.out_channels,\n        kernel_dim,\n        padding=padding,\n        dilation=dilation,\n        stride=stride)\n    # Repeat filter time_dim times along time dimension\n    weight_2d = conv2d.weight.data\n    if center:\n        weight_3d = torch.zeros(*weight_2d.shape)\n        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)\n        middle_idx = time_dim // 2\n        weight_3d[:, :, middle_idx, :, :] = weight_2d\n    else:\n        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)\n        weight_3d = weight_3d / time_dim\n\n    # Assign new params\n    conv3d.weight = nn.Parameter(weight_3d)\n    conv3d.bias = conv2d.bias\n    return conv3d\n\n\ndef inflate_linear(linear2d, time_dim):\n    \"\"\"\n    Args:\n        time_dim: final time dimension of the features\n    \"\"\"\n    linear3d = nn.Linear(linear2d.in_features * time_dim,\n                               linear2d.out_features)\n    weight3d = linear2d.weight.data.repeat(1, time_dim)\n    weight3d = weight3d / time_dim\n\n    linear3d.weight = nn.Parameter(weight3d)\n    linear3d.bias = linear2d.bias\n    return linear3d\n\n\ndef inflate_batch_norm(batch2d):\n    # In pytorch 0.2.0 the 2d and 3d 
versions of batch norm\n    # work identically except for the check that verifies the\n    # input dimensions\n\n    batch3d = nn.BatchNorm3d(batch2d.num_features)\n    # retrieve 3d _check_input_dim function\n    batch2d._check_input_dim = batch3d._check_input_dim\n    return batch2d\n\n\ndef inflate_pool(pool2d,\n                 time_dim=1,\n                 time_padding=0,\n                 time_stride=None,\n                 time_dilation=1):\n    kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size)\n    padding = (time_padding, pool2d.padding, pool2d.padding)\n    if time_stride is None:\n        time_stride = time_dim\n    stride = (time_stride, pool2d.stride, pool2d.stride)\n    if isinstance(pool2d, nn.MaxPool2d):\n        dilation = (time_dilation, pool2d.dilation, pool2d.dilation)\n        pool3d = nn.MaxPool3d(\n            kernel_dim,\n            padding=padding,\n            dilation=dilation,\n            stride=stride,\n            ceil_mode=pool2d.ceil_mode)\n    elif isinstance(pool2d, nn.AvgPool2d):\n        pool3d = nn.AvgPool3d(kernel_dim, stride=stride)\n    else:\n        raise ValueError(\n            '{} is not among known pooling classes'.format(type(pool2d)))\n    return pool3d\n\n\nclass MaxPool2dFor3dInput(nn.Module):\n    \"\"\"\n    Since nn.MaxPool3d is nondeterministic operation, using fixed random seeds can't get consistent results.\n    So we attempt to use max_pool2d to implement MaxPool3d with kernelsize (1, kernel_size, kernel_size).\n    \"\"\"\n    def __init__(self, kernel_size, stride=None, padding=0, dilation=1):\n        super().__init__()\n        self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation)\n    def forward(self, x):\n        b, c, t, h, w = x.size()\n        x = x.permute(0, 2, 1, 3, 4).contiguous() # b, t, c, h, w\n        x = x.view(b*t, c, h, w)\n        # max pooling\n        x = self.maxpool(x)\n        _, _, h, w = x.size()\n        x = 
x.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous()\n\n        return x"
  },
  {
    "path": "models/utils/nonlocal_blocks.py",
    "content": "import torch\nimport math\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom models.utils import inflate\n\n\nclass NonLocalBlockND(nn.Module):\n    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):\n        super(NonLocalBlockND, self).__init__()\n\n        assert dimension in [1, 2, 3]\n\n        self.dimension = dimension\n        self.sub_sample = sub_sample\n        self.in_channels = in_channels\n        self.inter_channels = inter_channels\n\n        if self.inter_channels is None:\n            self.inter_channels = in_channels // 2\n            if self.inter_channels == 0:\n                self.inter_channels = 1\n\n        if dimension == 3:\n            conv_nd = nn.Conv3d\n            # max_pool = inflate.MaxPool2dFor3dInput\n            max_pool = nn.MaxPool3d\n            bn = nn.BatchNorm3d\n        elif dimension == 2:\n            conv_nd = nn.Conv2d\n            max_pool = nn.MaxPool2d\n            bn = nn.BatchNorm2d\n        else:\n            conv_nd = nn.Conv1d\n            max_pool = nn.MaxPool1d\n            bn = nn.BatchNorm1d\n\n        self.g = conv_nd(self.in_channels, self.inter_channels,\n                         kernel_size=1, stride=1, padding=0, bias=True)\n        self.theta = conv_nd(self.in_channels, self.inter_channels,\n                             kernel_size=1, stride=1, padding=0, bias=True)\n        self.phi = conv_nd(self.in_channels, self.inter_channels,\n                           kernel_size=1, stride=1, padding=0, bias=True)\n        # if sub_sample:\n        #     self.g = nn.Sequential(self.g, max_pool(kernel_size=2))\n        #     self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))\n        if sub_sample:\n            if dimension == 3:\n                self.g = nn.Sequential(self.g, max_pool((1, 2, 2)))\n                self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2)))\n            else:\n                self.g = 
nn.Sequential(self.g, max_pool(kernel_size=2))\n                self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))\n\n        if bn_layer:\n            self.W = nn.Sequential(\n                conv_nd(self.inter_channels, self.in_channels,\n                        kernel_size=1, stride=1, padding=0, bias=True),\n                bn(self.in_channels)\n            )\n        else:\n            self.W = conv_nd(self.inter_channels, self.in_channels,\n                             kernel_size=1, stride=1, padding=0, bias=True)\n        \n        # init\n        for m in self.modules():\n            if isinstance(m, conv_nd):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, bn):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n        if bn_layer:\n            nn.init.constant_(self.W[1].weight.data, 0.0)\n            nn.init.constant_(self.W[1].bias.data, 0.0)\n        else:\n            nn.init.constant_(self.W.weight.data, 0.0)\n            nn.init.constant_(self.W.bias.data, 0.0)\n\n\n    def forward(self, x):\n        '''\n        :param x: (b, c, t, h, w)\n        :return:\n        '''\n        batch_size = x.size(0)\n\n        g_x = self.g(x).view(batch_size, self.inter_channels, -1)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)\n        theta_x = theta_x.permute(0, 2, 1)\n        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)\n        f = torch.matmul(theta_x, phi_x)\n        f = F.softmax(f, dim=-1)\n\n        y = torch.matmul(f, g_x)\n        y = y.permute(0, 2, 1).contiguous()\n        y = y.view(batch_size, self.inter_channels, *x.size()[2:])\n        y = self.W(y)\n        z = y + x\n\n        return z\n\n\nclass NonLocalBlock1D(NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, 
bn_layer=True):\n        super(NonLocalBlock1D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=1, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NonLocalBlock2D(NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NonLocalBlock2D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=2, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n\n\nclass NonLocalBlock3D(NonLocalBlockND):\n    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):\n        super(NonLocalBlock3D, self).__init__(in_channels,\n                                              inter_channels=inter_channels,\n                                              dimension=3, sub_sample=sub_sample,\n                                              bn_layer=bn_layer)\n"
  },
  {
    "path": "models/utils/pooling.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\n\nclass GeMPooling(nn.Module):\n    def __init__(self, p=3, eps=1e-6):\n        super().__init__()\n        self.p = nn.Parameter(torch.ones(1) * p)\n        self.eps = eps\n\n    def forward(self, x):\n        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), x.size()[2:]).pow(1./self.p)\n\n\nclass MaxAvgPooling(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.maxpooling = nn.AdaptiveMaxPool2d(1)\n        self.avgpooling = nn.AdaptiveAvgPool2d(1)\n\n    def forward(self, x):\n        max_f = self.maxpooling(x)\n        avg_f = self.avgpooling(x)\n\n        return torch.cat((max_f, avg_f), 1)\n        "
  },
  {
    "path": "models/vid_resnet.py",
    "content": "import torchvision\nimport torch.nn as nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom models.utils import inflate\nfrom models.utils import c3d_blocks\nfrom models.utils import nonlocal_blocks\n\n\n__all__ = ['AP3DResNet50', 'AP3DNLResNet50', 'NLResNet50', 'C2DResNet50', \n           'I3DResNet50', \n          ] \n\n\nclass Bottleneck3D(nn.Module):\n    def __init__(self, bottleneck2d, block, inflate_time=False, temperature=4, contrastive_att=True):\n        super().__init__()\n        self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1)\n        self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1)\n        if inflate_time == True:\n            self.conv2 = block(bottleneck2d.conv2, temperature=temperature, contrastive_att=contrastive_att)\n        else:\n            self.conv2 = inflate.inflate_conv(bottleneck2d.conv2, time_dim=1)\n        self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2)\n        self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1)\n        self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3)\n        self.relu = nn.ReLU(inplace=True)\n\n        if bottleneck2d.downsample is not None:\n            self.downsample = self._inflate_downsample(bottleneck2d.downsample)\n        else:\n            self.downsample = None\n\n    def _inflate_downsample(self, downsample2d, time_stride=1):\n        downsample3d = nn.Sequential(\n            inflate.inflate_conv(downsample2d[0], time_dim=1, \n                                 time_stride=time_stride),\n            inflate.inflate_batch_norm(downsample2d[1]))\n        return downsample3d\n\n    def forward(self, x):\n        residual = x\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            
residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet503D(nn.Module):\n    def __init__(self, config, block, c3d_idx, nl_idx, **kwargs):\n        super().__init__()\n        self.block = block\n        self.temperature = config.MODEL.AP3D.TEMPERATURE\n        self.contrastive_att = config.MODEL.AP3D.CONTRACTIVE_ATT\n\n        resnet2d = torchvision.models.resnet50(pretrained=True)\n        if config.MODEL.RES4_STRIDE == 1:\n            resnet2d.layer4[0].conv2.stride=(1, 1)\n            resnet2d.layer4[0].downsample[0].stride=(1, 1) \n\n        self.conv1 = inflate.inflate_conv(resnet2d.conv1, time_dim=1)\n        self.bn1 = inflate.inflate_batch_norm(resnet2d.bn1)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = inflate.inflate_pool(resnet2d.maxpool, time_dim=1)\n        # self.maxpool = inflate.MaxPool2dFor3dInput(kernel_size=resnet2d.maxpool.kernel_size,\n        #                                            stride=resnet2d.maxpool.stride,\n        #                                            padding=resnet2d.maxpool.padding,\n        #                                            dilation=resnet2d.maxpool.dilation)\n\n        self.layer1 = self._inflate_reslayer(resnet2d.layer1, c3d_idx=c3d_idx[0], \\\n                                             nonlocal_idx=nl_idx[0], nonlocal_channels=256)\n        self.layer2 = self._inflate_reslayer(resnet2d.layer2, c3d_idx=c3d_idx[1], \\\n                                             nonlocal_idx=nl_idx[1], nonlocal_channels=512)\n        self.layer3 = self._inflate_reslayer(resnet2d.layer3, c3d_idx=c3d_idx[2], \\\n                                             nonlocal_idx=nl_idx[2], nonlocal_channels=1024)\n        self.layer4 = self._inflate_reslayer(resnet2d.layer4, c3d_idx=c3d_idx[3], \\\n                                             nonlocal_idx=nl_idx[3], nonlocal_channels=2048)\n\n        self.bn = nn.BatchNorm1d(2048)\n        
init.normal_(self.bn.weight.data, 1.0, 0.02)\n        init.constant_(self.bn.bias.data, 0.0)\n\n    def _inflate_reslayer(self, reslayer2d, c3d_idx, nonlocal_idx=[], nonlocal_channels=0):\n        reslayers3d = []\n        for i,layer2d in enumerate(reslayer2d):\n            if i not in c3d_idx:\n                layer3d = Bottleneck3D(layer2d, c3d_blocks.C2D, inflate_time=False)\n            else:\n                layer3d = Bottleneck3D(layer2d, self.block, inflate_time=True, \\\n                                       temperature=self.temperature, contrastive_att=self.contrastive_att)\n            reslayers3d.append(layer3d)\n\n            if i in nonlocal_idx:\n                non_local_block = nonlocal_blocks.NonLocalBlock3D(nonlocal_channels, sub_sample=True)\n                reslayers3d.append(non_local_block)\n\n        return nn.Sequential(*reslayers3d)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        b, c, t, h, w = x.size()\n        x = x.permute(0, 2, 1, 3, 4).contiguous()\n        x = x.view(b*t, c, h, w)\n        # spatial max pooling\n        x = F.max_pool2d(x, x.size()[2:])\n        x = x.view(b, t, -1)\n        # temporal avg pooling\n        x = x.mean(1)\n        f = self.bn(x)\n\n        return f\n\n\ndef C2DResNet50(config, **kwargs):\n    c3d_idx = [[],[],[],[]]\n    nl_idx = [[],[],[],[]]\n\n    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)\n\n\ndef AP3DResNet50(config, **kwargs):\n    c3d_idx = [[],[0, 2],[0, 2, 4],[]]\n    nl_idx = [[],[],[],[]]\n\n    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)\n\n\ndef I3DResNet50(config, **kwargs):\n    c3d_idx = [[],[0, 2],[0, 2, 4],[]]\n    nl_idx = [[],[],[],[]]\n\n    return ResNet503D(config, c3d_blocks.I3D, c3d_idx, nl_idx, 
**kwargs)\n\n\ndef AP3DNLResNet50(config, **kwargs):\n    c3d_idx = [[],[0, 2],[0, 2, 4],[]]\n    nl_idx = [[],[1, 3],[1, 3, 5],[]]\n\n    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)\n\n\ndef NLResNet50(config, **kwargs):\n    c3d_idx = [[],[],[],[]]\n    nl_idx = [[],[1, 3],[1, 3, 5],[]]\n\n    return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs)\n"
  },
  {
    "path": "script.sh",
    "content": "# The code is built with DistributedDataParallel. \r\n# Reproducing the results in the paper requires training the model on 2 GPUs.\r\n# You can also train this model on a single GPU and double config.DATA.TRAIN_BATCH in configs.\r\n# For LTCC dataset\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ltcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 #\r\n# For PRCC dataset\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset prcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 #\r\n# For VC-Clothes dataset. You should change the root path of '--resume' to your output path.\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes --cfg configs/res50_cels_cal.yaml --gpu 0,1 #\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes_cc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume /data/guxinqian/logs/vcclothes/res50-cels-cal/best_model.pth.tar #\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset vcclothes_sc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --eval --resume /data/guxinqian/logs/vcclothes/res50-cels-cal/best_model.pth.tar #\r\n# For DeepChange dataset. Using amp can accelerate training.\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset deepchange --cfg configs/res50_cels_cal_16x4.yaml --amp --gpu 0,1 #\r\n# For LaST dataset. Using amp can accelerate training.\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset last --cfg configs/res50_cels_cal_tri_16x4.yaml --amp --gpu 0,1 #\r\n# For CCVID dataset\r\npython -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ccvid --cfg configs/c2dres50_ce_cal.yaml --gpu 0,1 #"
  },
  {
    "path": "test.py",
    "content": "import time\nimport datetime\nimport logging\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import distributed as dist\nfrom tools.eval_metrics import evaluate, evaluate_with_clothes\n\n\nVID_DATASET = ['ccvid']\n\n\ndef concat_all_gather(tensors, num_total_examples):\n    '''\n    Performs all_gather operation on the provided tensor list.\n    '''\n    outputs = []\n    for tensor in tensors:\n        tensor = tensor.cuda()\n        tensors_gather = [tensor.clone() for _ in range(dist.get_world_size())]\n        dist.all_gather(tensors_gather, tensor)\n        output = torch.cat(tensors_gather, dim=0).cpu()\n        # truncate the dummy elements added by DistributedInferenceSampler\n        outputs.append(output[:num_total_examples])\n    return outputs\n\n\n@torch.no_grad()\ndef extract_img_feature(model, dataloader):\n    features, pids, camids, clothes_ids = [], torch.tensor([]), torch.tensor([]), torch.tensor([])\n    for batch_idx, (imgs, batch_pids, batch_camids, batch_clothes_ids) in enumerate(dataloader):\n        flip_imgs = torch.flip(imgs, [3])\n        imgs, flip_imgs = imgs.cuda(), flip_imgs.cuda()\n        batch_features = model(imgs)\n        batch_features_flip = model(flip_imgs)\n        batch_features += batch_features_flip\n        batch_features = F.normalize(batch_features, p=2, dim=1)\n\n        features.append(batch_features.cpu())\n        pids = torch.cat((pids, batch_pids.cpu()), dim=0)\n        camids = torch.cat((camids, batch_camids.cpu()), dim=0)\n        clothes_ids = torch.cat((clothes_ids, batch_clothes_ids.cpu()), dim=0)\n    features = torch.cat(features, 0)\n\n    return features, pids, camids, clothes_ids\n\n\n@torch.no_grad()\ndef extract_vid_feature(model, dataloader, vid2clip_index, data_length):\n    # In build_dataloader, each original test video is split into a series of equilong clips.\n    # During test, we first extact features for all clips\n    clip_features, clip_pids, 
clip_camids, clip_clothes_ids = [], torch.tensor([]), torch.tensor([]), torch.tensor([])\n    for batch_idx, (vids, batch_pids, batch_camids, batch_clothes_ids) in enumerate(dataloader):\n        if (batch_idx + 1) % 200==0:\n            logger.info(\"{}/{}\".format(batch_idx+1, len(dataloader)))\n        vids = vids.cuda()\n        batch_features = model(vids)\n        clip_features.append(batch_features.cpu())\n        clip_pids = torch.cat((clip_pids, batch_pids.cpu()), dim=0)\n        clip_camids = torch.cat((clip_camids, batch_camids.cpu()), dim=0)\n        clip_clothes_ids = torch.cat((clip_clothes_ids, batch_clothes_ids.cpu()), dim=0)\n    clip_features = torch.cat(clip_features, 0)\n\n    # Gather samples from different GPUs\n    clip_features, clip_pids, clip_camids, clip_clothes_ids = \\\n        concat_all_gather([clip_features, clip_pids, clip_camids, clip_clothes_ids], data_length)\n\n    # Use the averaged feature of all clips split from a video as the representation of this original full-length video\n    features = torch.zeros(len(vid2clip_index), clip_features.size(1)).cuda()\n    clip_features = clip_features.cuda()\n    pids = torch.zeros(len(vid2clip_index))\n    camids = torch.zeros(len(vid2clip_index))\n    clothes_ids = torch.zeros(len(vid2clip_index))\n    for i, idx in enumerate(vid2clip_index):\n        features[i] = clip_features[idx[0] : idx[1], :].mean(0)\n        features[i] = F.normalize(features[i], p=2, dim=0)\n        pids[i] = clip_pids[idx[0]]\n        camids[i] = clip_camids[idx[0]]\n        clothes_ids[i] = clip_clothes_ids[idx[0]]\n    features = features.cpu()\n\n    return features, pids, camids, clothes_ids\n\n\ndef test(config, model, queryloader, galleryloader, dataset):\n    logger = logging.getLogger('reid.test')\n    since = time.time()\n    model.eval()\n    local_rank = dist.get_rank()\n    # Extract features \n    if config.DATA.DATASET in VID_DATASET:\n        qf, q_pids, q_camids, q_clothes_ids = 
extract_vid_feature(model, queryloader, \n                                                                  dataset.query_vid2clip_index,\n                                                                  len(dataset.recombined_query))\n        gf, g_pids, g_camids, g_clothes_ids = extract_vid_feature(model, galleryloader, \n                                                                  dataset.gallery_vid2clip_index,\n                                                                  len(dataset.recombined_gallery))\n    else:\n        qf, q_pids, q_camids, q_clothes_ids = extract_img_feature(model, queryloader)\n        gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader)\n        # Gather samples from different GPUs\n        torch.cuda.empty_cache()\n        qf, q_pids, q_camids, q_clothes_ids = concat_all_gather([qf, q_pids, q_camids, q_clothes_ids], len(dataset.query))\n        gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery))\n    torch.cuda.empty_cache()\n    time_elapsed = time.time() - since\n    \n    logger.info(\"Extracted features for query set, obtained {} matrix\".format(qf.shape))    \n    logger.info(\"Extracted features for gallery set, obtained {} matrix\".format(gf.shape))\n    logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n    # Compute distance matrix between query and gallery\n    since = time.time()\n    m, n = qf.size(0), gf.size(0)\n    distmat = torch.zeros((m,n))\n    qf, gf = qf.cuda(), gf.cuda()\n    # Cosine similarity\n    for i in range(m):\n        distmat[i] = (- torch.mm(qf[i:i+1], gf.t())).cpu()\n    distmat = distmat.numpy()\n    q_pids, q_camids, q_clothes_ids = q_pids.numpy(), q_camids.numpy(), q_clothes_ids.numpy()\n    g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy()\n    time_elapsed = time.time() - since\n    
logger.info('Distance computing in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n\n    since = time.time()\n    logger.info(\"Computing CMC and mAP\")\n    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)\n    logger.info(\"Results ---------------------------------------------------\")\n    logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))\n    logger.info(\"-----------------------------------------------------------\")\n    time_elapsed = time.time() - since\n    logger.info('Using {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n\n    if config.DATA.DATASET in ['last', 'deepchange', 'vcclothes_sc', 'vcclothes_cc']: return cmc[0]\n\n    logger.info(\"Computing CMC and mAP only for the same clothes setting\")\n    cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='SC')\n    logger.info(\"Results ---------------------------------------------------\")\n    logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))\n    logger.info(\"-----------------------------------------------------------\")\n\n    logger.info(\"Computing CMC and mAP only for clothes-changing\")\n    cmc, mAP = evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothes_ids, g_clothes_ids, mode='CC')\n    logger.info(\"Results ---------------------------------------------------\")\n    logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))\n    logger.info(\"-----------------------------------------------------------\")\n\n    return cmc[0]\n\n\ndef test_prcc(model, queryloader_same, queryloader_diff, galleryloader, dataset):\n    logger = logging.getLogger('reid.test')\n    since = time.time()\n    model.eval()\n    local_rank = dist.get_rank()\n    # Extract features for 
query set\n    qsf, qs_pids, qs_camids, qs_clothes_ids = extract_img_feature(model, queryloader_same)\n    qdf, qd_pids, qd_camids, qd_clothes_ids = extract_img_feature(model, queryloader_diff)\n    # Extract features for gallery set\n    gf, g_pids, g_camids, g_clothes_ids = extract_img_feature(model, galleryloader)\n    # Gather samples from different GPUs\n    torch.cuda.empty_cache()\n    qsf, qs_pids, qs_camids, qs_clothes_ids = concat_all_gather([qsf, qs_pids, qs_camids, qs_clothes_ids], len(dataset.query_same))\n    qdf, qd_pids, qd_camids, qd_clothes_ids = concat_all_gather([qdf, qd_pids, qd_camids, qd_clothes_ids], len(dataset.query_diff))\n    gf, g_pids, g_camids, g_clothes_ids = concat_all_gather([gf, g_pids, g_camids, g_clothes_ids], len(dataset.gallery))\n    time_elapsed = time.time() - since\n    \n    logger.info(\"Extracted features for query set (with same clothes), obtained {} matrix\".format(qsf.shape))\n    logger.info(\"Extracted features for query set (with different clothes), obtained {} matrix\".format(qdf.shape))\n    logger.info(\"Extracted features for gallery set, obtained {} matrix\".format(gf.shape))\n    logger.info('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))\n    # Compute distance matrix between query and gallery\n    m, n, k = qsf.size(0), qdf.size(0), gf.size(0)\n    distmat_same = torch.zeros((m, k))\n    distmat_diff = torch.zeros((n, k))\n    qsf, qdf, gf = qsf.cuda(), qdf.cuda(), gf.cuda()\n    # Cosine similarity\n    for i in range(m):\n        distmat_same[i] = (- torch.mm(qsf[i:i+1], gf.t())).cpu()\n    for i in range(n):\n        distmat_diff[i] = (- torch.mm(qdf[i:i+1], gf.t())).cpu()\n    distmat_same = distmat_same.numpy()\n    distmat_diff = distmat_diff.numpy()\n    qs_pids, qs_camids, qs_clothes_ids = qs_pids.numpy(), qs_camids.numpy(), qs_clothes_ids.numpy()\n    qd_pids, qd_camids, qd_clothes_ids = qd_pids.numpy(), qd_camids.numpy(), qd_clothes_ids.numpy()\n   
 g_pids, g_camids, g_clothes_ids = g_pids.numpy(), g_camids.numpy(), g_clothes_ids.numpy()\n\n    logger.info(\"Computing CMC and mAP for the same clothes setting\")\n    cmc, mAP = evaluate(distmat_same, qs_pids, g_pids, qs_camids, g_camids)\n    logger.info(\"Results ---------------------------------------------------\")\n    logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))\n    logger.info(\"-----------------------------------------------------------\")\n\n    logger.info(\"Computing CMC and mAP only for clothes changing\")\n    cmc, mAP = evaluate(distmat_diff, qd_pids, g_pids, qd_camids, g_camids)\n    logger.info(\"Results ---------------------------------------------------\")\n    logger.info('top1:{:.1%} top5:{:.1%} top10:{:.1%} top20:{:.1%} mAP:{:.1%}'.format(cmc[0], cmc[4], cmc[9], cmc[19], mAP))\n    logger.info(\"-----------------------------------------------------------\")\n\n    return cmc[0]"
  },
  {
    "path": "tools/eval_metrics.py",
    "content": "import logging\nimport numpy as np\n\n\ndef compute_ap_cmc(index, good_index, junk_index):\n    \"\"\" Compute AP and CMC for each sample\n    \"\"\"\n    ap = 0\n    cmc = np.zeros(len(index)) \n    \n    # remove junk_index\n    mask = np.in1d(index, junk_index, invert=True)\n    index = index[mask]\n\n    # find good_index index\n    ngood = len(good_index)\n    mask = np.in1d(index, good_index)\n    rows_good = np.argwhere(mask==True)\n    rows_good = rows_good.flatten()\n    \n    cmc[rows_good[0]:] = 1.0\n    for i in range(ngood):\n        d_recall = 1.0/ngood\n        precision = (i+1)*1.0/(rows_good[i]+1)\n        ap = ap + d_recall*precision\n\n    return ap, cmc\n\n\ndef evaluate(distmat, q_pids, g_pids, q_camids, g_camids):\n    \"\"\" Compute CMC and mAP\n\n    Args:\n        distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery).\n        q_pids (numpy array): person IDs for query samples.\n        g_pids (numpy array): person IDs for gallery samples.\n        q_camids (numpy array): camera IDs for query samples.\n        g_camids (numpy array): camera IDs for gallery samples.\n    \"\"\"\n    num_q, num_g = distmat.shape\n    index = np.argsort(distmat, axis=1) # from small to large\n\n    num_no_gt = 0 # num of query imgs without groundtruth\n    num_r1 = 0\n    CMC = np.zeros(len(g_pids))\n    AP = 0\n\n    for i in range(num_q):\n        # groundtruth index\n        query_index = np.argwhere(g_pids==q_pids[i])\n        camera_index = np.argwhere(g_camids==q_camids[i])\n        good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)\n        if good_index.size == 0:\n            num_no_gt += 1\n            continue\n        # remove gallery samples that have the same pid and camid with query\n        junk_index = np.intersect1d(query_index, camera_index)\n\n        ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index)\n        if CMC_tmp[0]==1:\n            num_r1 += 1\n        CMC 
= CMC + CMC_tmp\n        AP += ap_tmp\n\n    if num_no_gt > 0:\n        logger = logging.getLogger('reid.evaluate')\n        logger.info(\"{} query samples do not have groundtruth.\".format(num_no_gt))\n\n    CMC = CMC / (num_q - num_no_gt)\n    mAP = AP / (num_q - num_no_gt)\n\n    return CMC, mAP\n\n\ndef evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothids, g_clothids, mode='CC'):\n    \"\"\" Compute CMC and mAP with clothes\n\n    Args:\n        distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery).\n        q_pids (numpy array): person IDs for query samples.\n        g_pids (numpy array): person IDs for gallery samples.\n        q_camids (numpy array): camera IDs for query samples.\n        g_camids (numpy array): camera IDs for gallery samples.\n        q_clothids (numpy array): clothes IDs for query samples.\n        g_clothids (numpy array): clothes IDs for gallery samples.\n        mode: 'CC' for clothes-changing; 'SC' for the same clothes.\n    \"\"\"\n    assert mode in ['CC', 'SC']\n    \n    num_q, num_g = distmat.shape\n    index = np.argsort(distmat, axis=1) # from small to large\n\n    num_no_gt = 0 # num of query imgs without groundtruth\n    num_r1 = 0\n    CMC = np.zeros(len(g_pids))\n    AP = 0\n\n    for i in range(num_q):\n        # groundtruth index\n        query_index = np.argwhere(g_pids==q_pids[i])\n        camera_index = np.argwhere(g_camids==q_camids[i])\n        cloth_index = np.argwhere(g_clothids==q_clothids[i])\n        good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)\n        if mode == 'CC':\n            good_index = np.setdiff1d(good_index, cloth_index, assume_unique=True)\n            # remove gallery samples that have the same (pid, camid) or (pid, clothid) with query\n            junk_index1 = np.intersect1d(query_index, camera_index)\n            junk_index2 = np.intersect1d(query_index, cloth_index)\n            junk_index = np.union1d(junk_index1, 
junk_index2)\n        else:\n            good_index = np.intersect1d(good_index, cloth_index)\n            # remove gallery samples that have the same (pid, camid) or \n            # (the same pid and different clothid) with query\n            junk_index1 = np.intersect1d(query_index, camera_index)\n            junk_index2 = np.setdiff1d(query_index, cloth_index)\n            junk_index = np.union1d(junk_index1, junk_index2)\n\n        if good_index.size == 0:\n            num_no_gt += 1\n            continue\n    \n        ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index)\n        if CMC_tmp[0]==1:\n            num_r1 += 1\n        CMC = CMC + CMC_tmp\n        AP += ap_tmp\n\n    if num_no_gt > 0:\n        logger = logging.getLogger('reid.evaluate')\n        logger.info(\"{} query samples do not have groundtruth.\".format(num_no_gt))\n\n    if (num_q - num_no_gt) != 0:\n        CMC = CMC / (num_q - num_no_gt)\n        mAP = AP / (num_q - num_no_gt)\n    else:\n        mAP = 0\n\n    return CMC, mAP"
  },
  {
    "path": "tools/utils.py",
    "content": "import os\nimport sys\nimport shutil\nimport errno\nimport json\nimport os.path as osp\nimport torch\nimport random\nimport logging\nimport numpy as np\n\n\ndef set_seed(seed=None):\n    if seed is None:\n        return\n    random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = (\"%s\" % seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\ndef mkdir_if_missing(directory):\n    if not osp.exists(directory):\n        try:\n            os.makedirs(directory)\n        except OSError as e:\n            if e.errno != errno.EEXIST:\n                raise\n\n\ndef read_json(fpath):\n    with open(fpath, 'r') as f:\n        obj = json.load(f)\n    return obj\n\n\ndef write_json(obj, fpath):\n    mkdir_if_missing(osp.dirname(fpath))\n    with open(fpath, 'w') as f:\n        json.dump(obj, f, indent=4, separators=(',', ': '))\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value.\n       \n       Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262\n    \"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\ndef save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):\n    mkdir_if_missing(osp.dirname(fpath))\n    torch.save(state, fpath)\n    if is_best:\n        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))\n\n'''\nclass Logger(object):\n    \"\"\"\n    Write console output to external text file.\n    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.\n    \"\"\"\n    def 
__init__(self, fpath=None):\n        self.console = sys.stdout\n        self.file = None\n        if fpath is not None:\n            mkdir_if_missing(os.path.dirname(fpath))\n            self.file = open(fpath, 'w')\n\n    def __del__(self):\n        self.close()\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, *args):\n        self.close()\n\n    def write(self, msg):\n        self.console.write(msg)\n        if self.file is not None:\n            self.file.write(msg)\n\n    def flush(self):\n        self.console.flush()\n        if self.file is not None:\n            self.file.flush()\n            os.fsync(self.file.fileno())\n\n    def close(self):\n        self.console.close()\n        if self.file is not None:\n            self.file.close()\n'''\n\n\ndef get_logger(fpath, local_rank=0, name=''):\n    # Creat logger\n    logger = logging.getLogger(name)\n    level = logging.INFO if local_rank in [-1, 0] else logging.WARN\n    logger.setLevel(level=level)\n\n    # Output to console\n    console_handler = logging.StreamHandler(sys.stdout)\n    console_handler.setLevel(level=level) \n    console_handler.setFormatter(logging.Formatter('%(message)s'))\n    logger.addHandler(console_handler)\n\n    # Output to file\n    if fpath is not None:\n            mkdir_if_missing(os.path.dirname(fpath))\n    file_handler = logging.FileHandler(fpath, mode='w')\n    file_handler.setLevel(level=level)\n    file_handler.setFormatter(logging.Formatter('%(message)s'))\n    logger.addHandler(file_handler)\n\n    return logger"
  },
  {
    "path": "train.py",
    "content": "import time\nimport datetime\nimport logging\nimport torch\nfrom apex import amp\nfrom tools.utils import AverageMeter\n\n\ndef train_cal(config, epoch, model, classifier, clothes_classifier, criterion_cla, criterion_pair, \n    criterion_clothes, criterion_adv, optimizer, optimizer_cc, trainloader, pid2clothes):\n    logger = logging.getLogger('reid.train')\n    batch_cla_loss = AverageMeter()\n    batch_pair_loss = AverageMeter()\n    batch_clo_loss = AverageMeter()\n    batch_adv_loss = AverageMeter()\n    corrects = AverageMeter()\n    clothes_corrects = AverageMeter()\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    model.train()\n    classifier.train()\n    clothes_classifier.train()\n\n    end = time.time()\n    for batch_idx, (imgs, pids, camids, clothes_ids) in enumerate(trainloader):\n        # Get all positive clothes classes (belonging to the same identity) for each sample\n        pos_mask = pid2clothes[pids]\n        imgs, pids, clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda()\n        # Measure data loading time\n        data_time.update(time.time() - end)\n        # Forward\n        features = model(imgs)\n        outputs = classifier(features)\n        pred_clothes = clothes_classifier(features.detach())\n        _, preds = torch.max(outputs.data, 1)\n\n        # Update the clothes discriminator\n        clothes_loss = criterion_clothes(pred_clothes, clothes_ids)\n        if epoch >= config.TRAIN.START_EPOCH_CC:\n            optimizer_cc.zero_grad()\n            if config.TRAIN.AMP:\n                with amp.scale_loss(clothes_loss, optimizer_cc) as scaled_loss:\n                    scaled_loss.backward()\n            else:\n                clothes_loss.backward()\n            optimizer_cc.step()\n\n        # Update the backbone\n        new_pred_clothes = clothes_classifier(features)\n        _, clothes_preds = torch.max(new_pred_clothes.data, 1)\n\n        # 
Compute loss\n        cla_loss = criterion_cla(outputs, pids)\n        pair_loss = criterion_pair(features, pids)\n        adv_loss = criterion_adv(new_pred_clothes, clothes_ids, pos_mask)\n        if epoch >= config.TRAIN.START_EPOCH_ADV:\n            loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss   \n        else:\n            loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss   \n        optimizer.zero_grad()\n        if config.TRAIN.AMP:\n            with amp.scale_loss(loss, optimizer) as scaled_loss:\n                scaled_loss.backward()\n        else:\n            loss.backward()\n        optimizer.step()\n\n        # statistics\n        corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0))\n        clothes_corrects.update(torch.sum(clothes_preds == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0))\n        batch_cla_loss.update(cla_loss.item(), pids.size(0))\n        batch_pair_loss.update(pair_loss.item(), pids.size(0))\n        batch_clo_loss.update(clothes_loss.item(), clothes_ids.size(0))\n        batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0))\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n    logger.info('Epoch{0} '\n                  'Time:{batch_time.sum:.1f}s '\n                  'Data:{data_time.sum:.1f}s '\n                  'ClaLoss:{cla_loss.avg:.4f} '\n                  'PairLoss:{pair_loss.avg:.4f} '\n                  'CloLoss:{clo_loss.avg:.4f} '\n                  'AdvLoss:{adv_loss.avg:.4f} '\n                  'Acc:{acc.avg:.2%} '\n                  'CloAcc:{clo_acc.avg:.2%} '.format(\n                   epoch+1, batch_time=batch_time, data_time=data_time, \n                   cla_loss=batch_cla_loss, pair_loss=batch_pair_loss, \n                   clo_loss=batch_clo_loss, adv_loss=batch_adv_loss, \n                   acc=corrects, clo_acc=clothes_corrects))\n\n\ndef 
train_cal_with_memory(config, epoch, model, classifier, criterion_cla, criterion_pair, \n    criterion_adv, optimizer, trainloader, pid2clothes):\n    logger = logging.getLogger('reid.train')\n    batch_cla_loss = AverageMeter()\n    batch_pair_loss = AverageMeter()\n    batch_adv_loss = AverageMeter()\n    corrects = AverageMeter()\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    model.train()\n    classifier.train()\n\n    end = time.time()\n    for batch_idx, (imgs, pids, camids, clothes_ids) in enumerate(trainloader):\n        # Get all positive clothes classes (belonging to the same identity) for each sample\n        pos_mask = pid2clothes[pids]\n        imgs, pids, clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda()\n        # Measure data loading time\n        data_time.update(time.time() - end)\n        # Forward\n        features = model(imgs)\n        outputs = classifier(features)\n        _, preds = torch.max(outputs.data, 1)\n\n        # Compute loss\n        cla_loss = criterion_cla(outputs, pids)\n        pair_loss = criterion_pair(features, pids)\n\n        if epoch >= config.TRAIN.START_EPOCH_ADV:\n            adv_loss = criterion_adv(features, clothes_ids, pos_mask)\n            loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss   \n        else:\n            loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss  \n\n        optimizer.zero_grad()\n        if config.TRAIN.AMP:\n            with amp.scale_loss(loss, optimizer) as scaled_loss:\n                scaled_loss.backward()\n        else:\n            loss.backward()\n        optimizer.step()\n\n        # statistics\n        corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0))\n        batch_cla_loss.update(cla_loss.item(), pids.size(0))\n        batch_pair_loss.update(pair_loss.item(), pids.size(0))\n        if epoch >= config.TRAIN.START_EPOCH_ADV: \n            
batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0))\n        # measure elapsed time\n        batch_time.update(time.time() - end)\n        end = time.time()\n\n    logger.info('Epoch{0} '\n                'Time:{batch_time.sum:.1f}s '\n                'Data:{data_time.sum:.1f}s '\n                'ClaLoss:{cla_loss.avg:.4f} '\n                'PairLoss:{pair_loss.avg:.4f} '\n                'AdvLoss:{adv_loss.avg:.4f} '\n                'Acc:{acc.avg:.2%} '.format(\n                epoch+1, batch_time=batch_time, data_time=data_time, \n                cla_loss=batch_cla_loss, pair_loss=batch_pair_loss, \n                adv_loss=batch_adv_loss, acc=corrects))"
  }
]