[
  {
    "path": "README.md",
    "content": "# HHCL-ReID ![visitors](https://visitor-badge.glitch.me/badge?page_id=bupt-ai-cz.HHCL-ReID)\n[![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Codes%20for%20Our%20Paper:%20\"Hard-sample%20Guided%20Hybrid%20Contrast%20Learning%20for%20Unsupervised%20PersonRe-Identification\"%20&url=https://github.com/bupt-ai-cz/HHCL-ReID) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hard-sample-guided-hybrid-contrast-learning/unsupervised-person-re-identification-on-5)](https://paperswithcode.com/sota/unsupervised-person-re-identification-on-5?p=hard-sample-guided-hybrid-contrast-learning)  [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hard-sample-guided-hybrid-contrast-learning/unsupervised-person-re-identification-on-4)](https://paperswithcode.com/sota/unsupervised-person-re-identification-on-4?p=hard-sample-guided-hybrid-contrast-learning)\n\nThis repository is the official implementation of our paper \"[Hard-sample Guided Hybrid Contrast Learning for Unsupervised Person Re-Identification](https://arxiv.org/abs/2109.12333)!\".  
\n\n![framework_HCCL](img/framework_HCCL.jpg)\n\n## Requirements\n\n---\n\n    git clone https://github.com/bupt-ai-cz/HHCL-ReID.git\n    cd HHCL-ReID\n    pip install -r requirements.txt\n    python setup.py develop\n\n## Prepare Datasets\n\n---\n\nDownload the Market-1501, MSMT17, and DukeMTMC-reID datasets from this [link](https://drive.google.com/file/d/19oWiYGjTgouFMK_psZvH8ysDGQ1KUbk-/view?usp=sharing) and unzip them so that the directory layout looks like:\n\n    HHCL-ReID/examples/data\n    ├── market1501\n    │   └── Market-1501-v15.09.15\n    └── dukemtmcreid\n        └── DukeMTMC-reID\n\n### Prepare ImageNet Pre-trained Models for IBN-Net\n\nWhen training with the [IBN-ResNet](https://arxiv.org/abs/1807.09441) backbone, download the ImageNet-pretrained model from this [link](https://drive.google.com/drive/folders/1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S) and save it under `examples/pretrained/`:\n\n```\nHHCL-ReID/examples\n└── pretrained\n    └── resnet50_ibn_a.pth.tar\n```\n\n## Training\n\n---\n\nWe use 4 GeForce RTX 2080 Ti GPUs for training. Examples:\n\nMarket-1501:\n\n    CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/train.py -b 256 -a resnet50 -d market1501 --iters 200 --eps 0.45 --momentum 0.1 --num-instances 16 --pooling-type avg --memorybank CMhybrid --epochs 60 --logs-dir examples/logs/market1501/resnet50_avg_cmhybrid\n\nDukeMTMC-reID:\n\n    CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/train.py -b 256 -a resnet50 -d dukemtmcreid --iters 200 --eps 0.6 --momentum 0.1 --num-instances 16 --pooling-type avg --memorybank CMhybrid --epochs 60 --logs-dir examples/logs/dukemtmcreid/resnet50_avg_cmhybrid\n\n- use `-a resnet50` (default) for the ResNet-50 backbone, and `-a resnet_ibn50a` for the IBN-ResNet backbone;\n- use `--pooling-type gem` for Generalized Mean (GeM) pooling and `--smooth <value>` (a float, default 0) for label smoothing. 
\n\n## Evaluation\n\n---\n\nTo evaluate a trained model, run the command below, replacing `$DATASET` with the dataset name (e.g. `market1501` or `dukemtmcreid`) and `$PATH` with the path to the checkpoint:\n\n    CUDA_VISIBLE_DEVICES=0 python examples/test.py -d $DATASET --resume $PATH --pooling-type avg\n\n## Results\n\n---\n\nOur model achieves the following performance on Market-1501 and DukeMTMC-reID:\n\n| Dataset            | Market1501 |      |      |      | DukeMTMC-reID |      |      |      |\n| ------------------ | ---------- | ---- | ---- | ---- | ------------- | ---- | ---- | ---- |\n| Setting            | mAP        | R1   | R5   | R10  | mAP           | R1   | R5   | R10  |\n| Fully Unsupervised | 84.2       | 93.4 | 97.7 | 98.5 | 73.3          | 85.1 | 92.4 | 94.6 |\n| Supervised         | 87.2       | 94.6 | 98.5 | 99.1 | 80.0          | 89.8 | 95.2 | 96.7 |\n\nThe models reported in the paper can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1WQw7wD2Mu_1SKl07_NdKvrYf2xrs3CEZ).\n\n## Citation\n\n---\n\nIf you find this code useful for your research, please cite our paper:\n\n```\n@article{hu2021hard,\n  title={Hard-sample Guided Hybrid Contrast Learning for Unsupervised Person Re-Identification},\n  author={Hu, Zheng and Zhu, Chuang and He, Gang},\n  journal={arXiv preprint arXiv:2109.12333},\n  year={2021}\n}\n```\n\n## Acknowledgements\n\n---\n\nThis project would not have been possible without several great open-source codebases, listed below.\n\n- [SpCL](https://github.com/yxgeee/SpCL)\n- [cluster-contrast-reid](https://github.com/alibaba/cluster-contrast-reid)\n"
  },
  {
    "path": "examples/test.py",
    "content": "from __future__ import print_function, absolute_import\nimport argparse\nimport os.path as osp\nimport random\nimport numpy as np\nimport sys\n\nimport torch\nfrom torch import nn\nfrom torch.backends import cudnn\nfrom torch.utils.data import DataLoader\n\nfrom hhcl import datasets\nfrom hhcl import models\nfrom hhcl.models.dsbn import convert_dsbn, convert_bn\nfrom hhcl.evaluators import Evaluator\nfrom hhcl.utils.data import transforms as T\nfrom hhcl.utils.data.preprocessor import Preprocessor\nfrom hhcl.utils.logging import Logger\nfrom hhcl.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict\n\n\ndef get_data(name, data_dir, height, width, batch_size, workers):\n    root = osp.join(data_dir, name)\n\n    dataset = datasets.create(name, root)\n\n    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],\n                             std=[0.229, 0.224, 0.225])\n\n    test_transformer = T.Compose([\n             T.Resize((height, width), interpolation=3),\n             T.ToTensor(),\n             normalizer\n         ])\n\n    test_loader = DataLoader(\n        Preprocessor(list(set(dataset.query) | set(dataset.gallery)),\n                     root=dataset.images_dir, transform=test_transformer),\n        batch_size=batch_size, num_workers=workers,\n        shuffle=False, pin_memory=True)\n    return dataset, test_loader\n\n\ndef main():\n    args = parser.parse_args()\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n\n    main_worker(args)\n\n\ndef main_worker(args):\n    cudnn.benchmark = True\n\n    log_dir = osp.dirname(args.resume)\n    sys.stdout = Logger(osp.join(log_dir, 'log_test.txt'))\n    print(\"==========\\nArgs:{}\\n==========\".format(args))\n\n    # Create data loaders\n    dataset, test_loader = get_data(args.dataset, args.data_dir, args.height,\n                                    
args.width, args.batch_size, args.workers)\n\n    # Create model\n    model = models.create(args.arch, pretrained=False, num_features=args.features, dropout=args.dropout,\n                          num_classes=0, pooling_type=args.pooling_type)\n    if args.dsbn:\n        print(\"==> Load the model with domain-specific BNs\")\n        convert_dsbn(model)\n\n    # Load from checkpoint\n    checkpoint = load_checkpoint(args.resume)\n    copy_state_dict(checkpoint['state_dict'], model, strip='module.')\n\n    if args.dsbn:\n        print(\"==> Test with {}-domain BNs\".format(\"source\" if args.test_source else \"target\"))\n        convert_bn(model, use_target=(not args.test_source))\n\n    model.cuda()\n    model = nn.DataParallel(model)\n\n    # Evaluator\n    model.eval()\n    evaluator = Evaluator(model)\n    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True, rerank=args.rerank)\n    return\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description=\"Testing the model\")\n    # data\n    parser.add_argument('-d', '--dataset', type=str, default='market1501')\n    parser.add_argument('-b', '--batch-size', type=int, default=256)\n    parser.add_argument('-j', '--workers', type=int, default=4)\n    parser.add_argument('--height', type=int, default=256, help=\"input height\")\n    parser.add_argument('--width', type=int, default=128, help=\"input width\")\n    # model\n    parser.add_argument('-a', '--arch', type=str, default='resnet50',\n                        choices=models.names())\n    parser.add_argument('--features', type=int, default=0)\n    parser.add_argument('--dropout', type=float, default=0)\n\n    parser.add_argument('--resume', type=str,\n                        default=\"examples/logs/market1501/resnet50_avg/model_best.pth.tar\",\n                        metavar='PATH')\n    # testing configs\n    parser.add_argument('--rerank', action='store_true',\n                        help=\"evaluation only\")\n  
  parser.add_argument('--dsbn', action='store_true',\n                        help=\"test on the model with domain-specific BN\")\n    parser.add_argument('--test-source', action='store_true',\n                        help=\"test on the source domain\")\n    parser.add_argument('--seed', type=int, default=1)\n    # path\n    working_dir = osp.dirname(osp.abspath(__file__))\n    parser.add_argument('--data-dir', type=str, metavar='PATH',\n                        default='examples/data')\n    parser.add_argument('--pooling-type', type=str, default='avg')\n    parser.add_argument('--embedding_features_path', type=str,\n                        default='examples/logs/market1501/resnet50_avg/')\n    main()\n"
  },
  {
    "path": "examples/train.py",
    "content": "# -*- coding: utf-8 -*-\nfrom __future__ import print_function, absolute_import\nimport argparse\nimport os.path as osp\nimport random\nimport numpy as np\nimport sys\nimport collections\nimport time\nfrom datetime import timedelta\n\nfrom sklearn.cluster import DBSCAN\n\nimport torch\nfrom torch import nn\nfrom torch.backends import cudnn\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\nfrom hhcl import datasets\nfrom hhcl import models\nfrom hhcl.models.cm import ClusterMemory\nfrom hhcl.trainers import Trainer\nfrom hhcl.evaluators import Evaluator, extract_features\nfrom hhcl.utils.data import IterLoader\nfrom hhcl.utils.data import transforms as T\nfrom hhcl.utils.data.sampler import RandomMultipleGallerySampler\nfrom hhcl.utils.data.preprocessor import Preprocessor\nfrom hhcl.utils.logging import Logger\nfrom hhcl.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict\nfrom hhcl.utils.faiss_rerank import compute_jaccard_distance\n\nstart_epoch = best_mAP = 0\n\n\ndef get_data(name, data_dir):\n    root = osp.join(data_dir, name)\n    dataset = datasets.create(name, root)\n    return dataset\n\n\ndef get_train_loader(args, dataset, height, width, batch_size, workers,\n                     num_instances, iters, trainset=None):\n\n    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],\n                             std=[0.229, 0.224, 0.225])\n\n    train_transformer = T.Compose([\n        T.Resize((height, width), interpolation=3),\n        T.RandomHorizontalFlip(p=0.5),\n        T.Pad(10),\n        T.RandomCrop((height, width)),\n        T.ToTensor(),\n        normalizer,\n        T.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406])\n    ])\n\n    train_set = sorted(dataset.train) if trainset is None else sorted(trainset)\n    rmgs_flag = num_instances > 0\n    if rmgs_flag:\n        sampler = RandomMultipleGallerySampler(train_set, num_instances)\n    else:\n        sampler = None\n    
train_loader = IterLoader(\n        DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer),\n                   batch_size=batch_size, num_workers=workers, sampler=sampler,\n                   shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=iters)\n\n    return train_loader\n\n\ndef get_test_loader(dataset, height, width, batch_size, workers, testset=None):\n    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],\n                             std=[0.229, 0.224, 0.225])\n\n    test_transformer = T.Compose([\n        T.Resize((height, width), interpolation=3),\n        T.ToTensor(),\n        normalizer\n    ])\n\n    if testset is None:\n        testset = list(set(dataset.query) | set(dataset.gallery))\n\n    test_loader = DataLoader(\n        Preprocessor(testset, root=dataset.images_dir, transform=test_transformer),\n        batch_size=batch_size, num_workers=workers,\n        shuffle=False, pin_memory=True)\n\n    return test_loader\n\n\ndef create_model(args):\n    model = models.create(args.arch, num_features=args.features, norm=True, dropout=args.dropout,\n                          num_classes=0, pooling_type=args.pooling_type)\n    \n    # Load from checkpoint\n    if args.resume:\n        global start_epoch\n        checkpoint = load_checkpoint(args.resume)\n        copy_state_dict(checkpoint['state_dict'], model, strip='module.')\n        start_epoch = checkpoint['epoch']\n    \n    # use CUDA\n    model.cuda()\n    model = nn.DataParallel(model)\n    return model\n\n\ndef main():\n    args = parser.parse_args()\n\n    if args.seed is not None:\n        random.seed(args.seed)\n        np.random.seed(args.seed)\n        torch.manual_seed(args.seed)\n        cudnn.deterministic = True\n\n    main_worker(args)\n\n\ndef main_worker(args):\n    global start_epoch, best_mAP\n    start_time = time.monotonic()\n\n    cudnn.benchmark = True\n\n    sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))\n    
print(\"==========\\nArgs:{}\\n==========\".format(args))\n\n    # Create datasets\n    iters = args.iters if (args.iters > 0) else None\n    print(\"==> Load unlabeled dataset\")\n    dataset = get_data(args.dataset, args.data_dir)\n    test_loader = get_test_loader(dataset, args.height, args.width, args.batch_size, args.workers)\n\n    # Create model\n    model = create_model(args)\n\n    # Evaluator\n    evaluator = Evaluator(model)\n\n    # Optimizer\n    params = [{\"params\": [value]} for _, value in model.named_parameters() if value.requires_grad]\n    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)\n    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)\n\n    # Trainer\n    trainer = Trainer(model)\n\n    for epoch in range(start_epoch, args.epochs):\n        with torch.no_grad():\n            print('==> Create pseudo labels for unlabeled data')\n            cluster_loader = get_test_loader(dataset, args.height, args.width,\n                                             args.batch_size, args.workers, testset=sorted(dataset.train))\n\n            features, _ = extract_features(model, cluster_loader, print_freq=50)\n            features = torch.cat([features[f].unsqueeze(0) for f, _, _ in sorted(dataset.train)], 0)\n            rerank_dist = compute_jaccard_distance(features, k1=args.k1, k2=args.k2)\n\n            if epoch == start_epoch:\n                # DBSCAN cluster\n                eps = args.eps\n                print('Clustering criterion eps: {:.3f}'.format(eps))\n                cluster = DBSCAN(eps=eps, min_samples=4, metric='precomputed', n_jobs=-1)\n\n            # select & cluster images as training set of this epochs\n            pseudo_labels = cluster.fit_predict(rerank_dist)\n            num_cluster = len(set(pseudo_labels)) - (1 if -1 in pseudo_labels else 0)\n\n        # generate new dataset and calculate cluster centers\n        @torch.no_grad()\n        def 
generate_cluster_features(labels, features):\n            # Average the features of each cluster; outliers (label -1) are skipped\n            centers = collections.defaultdict(list)\n            for i, label in enumerate(labels):\n                if label == -1:\n                    continue\n                centers[label].append(features[i])\n\n            centers = [\n                torch.stack(centers[idx], dim=0).mean(0) for idx in sorted(centers.keys())\n            ]\n\n            centers = torch.stack(centers, dim=0)\n            return centers\n\n        cluster_features = generate_cluster_features(pseudo_labels, features)\n\n        def generate_random_features(labels, features, num_cluster, num_instances):\n            # Sample num_instances features (with replacement) from every cluster.\n            # The index array must be integer-typed to be usable for indexing.\n            indexes = np.zeros(num_cluster * num_instances, dtype=np.int64)\n            for i in range(num_cluster):\n                index = [i + k * num_cluster for k in range(num_instances)]\n                samples = np.random.choice(np.where(labels == i)[0], num_instances, True)\n                indexes[index] = samples\n            memory_features = features[torch.from_numpy(indexes)]\n            return memory_features\n\n        if args.memorybank=='CMhybrid_v2':\n            memory_features = generate_random_features(pseudo_labels, features, num_cluster, args.num_instances)\n            mask = (pseudo_labels < 0).astype(int)\n            print('==> Statistics for outliers with pseudo labels. 
outliers/total = {}/{} = {:.3f}'.format(mask.sum(), pseudo_labels.size, mask.sum()/pseudo_labels.size))\n            \n        del cluster_loader, features\n\n        # Create memory bank\n        memory = ClusterMemory(model.module.num_features, num_cluster, temp=args.temp,\n                                momentum=args.momentum, mode=args.memorybank, smooth=args.smooth,\n                                num_instances=args.num_instances).cuda()\n        if args.memorybank=='CMhybrid':\n            memory.features = F.normalize(cluster_features.repeat(2, 1), dim=1).cuda()\n        elif args.memorybank=='CMhybrid_v2':\n            memory.features = F.normalize(torch.cat([cluster_features, memory_features],dim=0), dim=1).cuda()\n        else:\n            memory.features = F.normalize(cluster_features, dim=1).cuda()\n\n        trainer.memory = memory\n\n        pseudo_labeled_dataset = []\n        for i, ((fname, _, cid), label) in enumerate(zip(sorted(dataset.train), pseudo_labels)):\n            if label != -1:\n                pseudo_labeled_dataset.append((fname, label.item(), cid))\n\n        print('==> Statistics for epoch {}: {} clusters'.format(epoch, num_cluster))\n\n        train_loader = get_train_loader(args, dataset, args.height, args.width,\n                                        args.batch_size, args.workers, args.num_instances, iters,\n                                        trainset=pseudo_labeled_dataset)\n\n        train_loader.new_epoch()\n\n        trainer.train(epoch, train_loader, optimizer,\n                      print_freq=args.print_freq, train_iters=len(train_loader))\n\n        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):\n            mAP = evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=False)\n            is_best = (mAP > best_mAP)\n            best_mAP = max(mAP, best_mAP)\n            save_checkpoint({\n                'state_dict': model.state_dict(),\n                'epoch': epoch + 
1,\n                'best_mAP': best_mAP,\n            }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))\n\n            print('\\n * Finished epoch {:3d}  model mAP: {:5.1%}  best: {:5.1%}{}\\n'.\n                  format(epoch, mAP, best_mAP, ' *' if is_best else ''))\n\n        lr_scheduler.step()\n\n    print('==> Test with the best model:')\n    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))\n    model.load_state_dict(checkpoint['state_dict'])\n    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True)\n\n    end_time = time.monotonic()\n    print('Total running time: ', timedelta(seconds=end_time - start_time))\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description=\"Hard-sample Guided Hybrid Contrast Learning for Unsupervised Person Re-ID\")\n    # data\n    parser.add_argument('-d', '--dataset', type=str, default='dukemtmcreid',\n                        choices=datasets.names())\n    parser.add_argument('-b', '--batch-size', type=int, default=2)\n    parser.add_argument('-j', '--workers', type=int, default=4)\n    parser.add_argument('--height', type=int, default=256, help=\"input height\")\n    parser.add_argument('--width', type=int, default=128, help=\"input width\")\n    parser.add_argument('--num-instances', type=int, default=4,\n                        help=\"each minibatch consists of \"\n                             \"(batch_size // num_instances) identities, and \"\n                             \"each identity has num_instances instances; \"\n                             \"0 means the identity sampler is not used (default: 4)\")\n    # cluster\n    parser.add_argument('--eps', type=float, default=0.6,\n                        help=\"max neighbor distance for DBSCAN\")\n    parser.add_argument('--eps-gap', type=float, default=0.02,\n                        help=\"multi-scale criterion for measuring cluster reliability\")\n    parser.add_argument('--k1', type=int, 
default=30,\n                        help=\"hyperparameter for jaccard distance\")\n    parser.add_argument('--k2', type=int, default=6,\n                        help=\"hyperparameter for jaccard distance\")\n\n    # model\n    parser.add_argument('-a', '--arch', type=str, default='resnet50',\n                        choices=models.names())\n    parser.add_argument('--features', type=int, default=0)\n    parser.add_argument('--dropout', type=float, default=0)\n    parser.add_argument('--smooth', type=float, default=0, help=\"label smoothing\")\n    parser.add_argument('--hard-weight', type=float, default=0.5, help=\"hard weights\")\n    parser.add_argument('--momentum', type=float, default=0.1,\n                        help=\"update momentum for the memory bank\")\n    parser.add_argument('--pooling-type', type=str, default='gem')\n    parser.add_argument('-mb', '--memorybank', type=str, default='CM', choices=['CM', 'CMhard', 'CMhybrid', 'CMhybrid_v2'])\n\n    # optimizer\n    parser.add_argument('--lr', type=float, default=0.00035,\n                        help=\"learning rate\")\n    parser.add_argument('--weight-decay', type=float, default=5e-4)\n    parser.add_argument('--epochs', type=int, default=50)\n    parser.add_argument('--iters', type=int, default=400)\n    parser.add_argument('--step-size', type=int, default=20)\n    # training configs\n    parser.add_argument('--seed', type=int, default=1)\n    parser.add_argument('--print-freq', type=int, default=10)\n    parser.add_argument('--eval-step', type=int, default=10)\n    parser.add_argument('--temp', type=float, default=0.05,\n                        help=\"temperature for scaling contrastive loss\")\n    # path\n    working_dir = osp.dirname(osp.abspath(__file__))\n    parser.add_argument('--data-dir', type=str, metavar='PATH',\n                        default=osp.join(working_dir, 'data'))\n    parser.add_argument('--logs-dir', type=str, metavar='PATH',\n                        
default=osp.join(working_dir, 'logs'))\n    parser.add_argument('--resume', type=str, metavar='PATH', default='')\n    main()\n"
  },
  {
    "path": "hhcl/__init__.py",
    "content": "from __future__ import absolute_import\n\nfrom . import datasets\nfrom . import evaluation_metrics\nfrom . import models\nfrom . import utils\nfrom . import evaluators\nfrom . import trainers\n\n__version__ = '0.1.0'\n"
  },
  {
    "path": "hhcl/datasets/__init__.py",
    "content": "from __future__ import absolute_import\nimport warnings\n\nfrom .market1501 import Market1501\nfrom .msmt17 import MSMT17\nfrom .personx import PersonX\nfrom .dukemtmcreid import DukeMTMCreID\nfrom .celebreid import CelebReID\n\n\n__factory = {\n    'market1501': Market1501,\n    'msmt17': MSMT17,\n    'personx': PersonX,\n    'dukemtmcreid': DukeMTMCreID,\n    'celebreid': CelebReID\n}\n\n\ndef names():\n    return sorted(__factory.keys())\n\n\ndef create(name, root, *args, **kwargs):\n    \"\"\"\n    Create a dataset instance.\n\n    Parameters\n    ----------\n    name : str\n        The dataset name. \n    root : str\n        The path to the dataset directory.\n    split_id : int, optional\n        The index of data split. Default: 0\n    num_val : int or float, optional\n        When int, it means the number of validation identities. When float,\n        it means the proportion of validation to all the trainval. Default: 100\n    download : bool, optional\n        If True, will download the dataset. Default: False\n    \"\"\"\n    if name not in __factory:\n        raise KeyError(\"Unknown dataset:\", name)\n    return __factory[name](root, *args, **kwargs)\n\n\ndef get_dataset(name, root, *args, **kwargs):\n    warnings.warn(\"get_dataset is deprecated. Use create instead.\")\n    return create(name, root, *args, **kwargs)\n"
  },
  {
    "path": "hhcl/datasets/celebreid.py",
    "content": "from __future__ import print_function, absolute_import\nimport os.path as osp\nimport glob\nimport re\nfrom ..utils.data import BaseImageDataset\n\n\nclass CelebReID(BaseImageDataset):\n    \"\"\"\n    CelebReID\n    \"\"\"\n    dataset_dir = 'CelebReID'\n\n    def __init__(self, root, verbose=True, **kwargs):\n        super(CelebReID, self).__init__()\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')\n\n        self._check_before_run()\n\n        train = self._process_dir(self.train_dir, relabel=True)\n        query = self._process_dir(self.query_dir, relabel=False)\n        gallery = self._process_dir(self.gallery_dir, relabel=False)\n\n        if verbose:\n            print(\"=> CelebReID loaded\")\n            self.print_dataset_statistics(train, query, gallery)\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_dir))\n        if not osp.exists(self.gallery_dir):\n            raise RuntimeError(\"'{}' 
is not available\".format(self.gallery_dir))\n\n    def _process_dir(self, dir_path, relabel=False):\n        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n        # pattern = re.compile(r'([-\\d]+)_c(\\d)')\n        pattern = re.compile(r'([-\\d]+)_(\\d)')\n\n        pid_container = set()\n        for img_path in img_paths:\n            pid, _ = map(int, pattern.search(img_path).groups())\n            if pid == -1:\n                continue  # junk images are just ignored\n            pid_container.add(pid)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n\n        dataset = []\n        for img_path in img_paths:\n            pid, camid = map(int, pattern.search(img_path).groups())\n            if pid == -1:\n                continue  # junk images are just ignored\n            # assert 0 <= pid <= 1501  # pid == 0 means background\n            # assert 1 <= camid <= 6\n            camid -= 1  # index starts from 0\n            if relabel:\n                pid = pid2label[pid]\n            dataset.append((img_path, pid, camid))\n\n        return dataset\n"
  },
  {
    "path": "hhcl/datasets/dukemtmcreid.py",
    "content": "import glob\nimport os.path as osp\nimport re\nfrom ..utils.data import BaseImageDataset\n\n\ndef process_dir(dir_path, relabel=False):\n    img_paths = glob.glob(osp.join(dir_path, \"*.jpg\"))\n    pattern = re.compile(r\"([-\\d]+)_c(\\d)\")\n\n    # get all identities\n    pid_container = set()\n    for img_path in img_paths:\n        pid, _ = map(int, pattern.search(img_path).groups())\n        if pid == -1:\n            continue\n        pid_container.add(pid)\n\n    pid2label = {pid: label for label, pid in enumerate(pid_container)}\n\n    data = []\n    for img_path in img_paths:\n        pid, camid = map(int, pattern.search(img_path).groups())\n        if (pid not in pid_container) or (pid == -1):\n            continue\n\n        assert 1 <= camid <= 8\n        camid -= 1\n\n        if relabel:\n            pid = pid2label[pid]\n        data.append((img_path, pid, camid))\n\n    return data\n\n\nclass DukeMTMCreID(BaseImageDataset):\n\n    \"\"\"DukeMTMC-reID.\n    Reference:\n        - Ristani et al. Performance Measures and a Data Set for Multi-Target,\n            Multi-Camera Tracking. ECCVW 2016.\n        - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person\n            Re-identification Baseline in vitro. 
ICCV 2017.\n    URL: `<https://github.com/layumi/DukeMTMC-reID_evaluation>`_\n\n    Dataset statistics:\n        - identities: 1404 (train + query).\n        - images:16522 (train) + 2228 (query) + 17661 (gallery).\n        - cameras: 8.\n    \"\"\"\n\n    dataset_dir = \"DukeMTMC-reID\"\n\n    def __init__(self, root, verbose=True):\n        super(DukeMTMCreID, self).__init__()\n        self.root = osp.abspath(osp.expanduser(root))\n        self.dataset_dir = osp.join(self.root, self.dataset_dir)\n\n        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')\n\n        train = process_dir(dir_path=self.train_dir, relabel=True)\n        query = process_dir(dir_path=self.query_dir, relabel=False)\n        gallery = process_dir(dir_path=self.gallery_dir, relabel=False)\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_dir))\n        if not osp.exists(self.gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.gallery_dir))\n"
  },
  {
    "path": "hhcl/datasets/market1501.py",
    "content": "from __future__ import print_function, absolute_import\nimport os.path as osp\nimport glob\nimport re\nfrom ..utils.data import BaseImageDataset\n\n\nclass Market1501(BaseImageDataset):\n    \"\"\"\n    Market1501\n    Reference:\n    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.\n    URL: http://www.liangzheng.org/Project/project_reid.html\n\n    Dataset statistics:\n    # identities: 1501 (+1 for background)\n    # images: 12936 (train) + 3368 (query) + 15913 (gallery)\n    \"\"\"\n    dataset_dir = 'Market-1501-v15.09.15'\n\n    def __init__(self, root, verbose=True, **kwargs):\n        super(Market1501, self).__init__()\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')\n\n        self._check_before_run()\n\n        train = self._process_dir(self.train_dir, relabel=True)\n        query = self._process_dir(self.query_dir, relabel=False)\n        gallery = self._process_dir(self.gallery_dir, relabel=False)\n\n        if verbose:\n            print(\"=> Market1501 loaded\")\n            self.print_dataset_statistics(train, query, gallery)\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if 
not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_dir))\n        if not osp.exists(self.gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.gallery_dir))\n\n    def _process_dir(self, dir_path, relabel=False):\n        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n        pattern = re.compile(r'([-\\d]+)_c(\\d)')\n\n        pid_container = set()\n        for img_path in img_paths:\n            pid, _ = map(int, pattern.search(img_path).groups())\n            if pid == -1:\n                continue  # junk images are just ignored\n            pid_container.add(pid)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n\n        dataset = []\n        for img_path in img_paths:\n            pid, camid = map(int, pattern.search(img_path).groups())\n            if pid == -1:\n                continue  # junk images are just ignored\n            assert 0 <= pid <= 1501  # pid == 0 means background\n            assert 1 <= camid <= 6\n            camid -= 1  # index starts from 0\n            if relabel:\n                pid = pid2label[pid]\n            dataset.append((img_path, pid, camid))\n\n        return dataset\n"
  },
  {
    "path": "hhcl/datasets/msmt17.py",
    "content": "from __future__ import print_function, absolute_import\nimport os.path as osp\n\nimport glob\nimport re\nfrom ..utils.data import BaseImageDataset\n\n\ndef _process_dir(dir_path, relabel=False):\n    img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n    pattern = re.compile(r'([-\\d]+)_c(\\d+)')\n\n    pid_container = set()\n    for img_path in img_paths:\n        pid, _ = map(int, pattern.search(img_path).groups())\n        if pid == -1:\n            continue  # junk images are just ignored\n        pid_container.add(pid)\n    pid2label = {pid: label for label, pid in enumerate(pid_container)}\n    dataset = []\n    for img_path in img_paths:\n        pid, camid = map(int, pattern.search(img_path).groups())\n        if pid == -1:\n            continue  # junk images are just ignored\n        assert 1 <= camid <= 15\n        camid -= 1  # index starts from 0\n        if relabel:\n            pid = pid2label[pid]\n        dataset.append((img_path, pid, camid))\n\n    return dataset\n\n\nclass MSMT17(BaseImageDataset):\n    dataset_dir = 'MSMT17_V1'\n\n    def __init__(self, root, verbose=True, **kwargs):\n        super(MSMT17, self).__init__()\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')\n\n        self._check_before_run()\n\n        train = _process_dir(self.train_dir, relabel=True)\n        query = _process_dir(self.query_dir, relabel=False)\n        gallery = _process_dir(self.gallery_dir, relabel=False)\n\n        if verbose:\n            print(\"=> MSMT17_V1 loaded\")\n            self.print_dataset_statistics(train, query, gallery)\n\n        # Set these unconditionally so the attributes also exist when verbose=False\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = 
self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.train_dir))\n        if not osp.exists(self.query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_dir))\n        if not osp.exists(self.gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.gallery_dir))\n"
  },
  {
    "path": "hhcl/datasets/personx.py",
    "content": "from __future__ import print_function, absolute_import\nimport os.path as osp\nimport glob\nimport re\n\nfrom ..utils.data import BaseImageDataset\n\n\nclass PersonX(BaseImageDataset):\n    \"\"\"\n    PersonX\n    Reference:\n    Sun et al. Dissecting Person Re-identification from the Viewpoint of Viewpoint. CVPR 2019.\n\n    Dataset statistics:\n    # identities: 1266\n    # images: 9840 (train) + 5136 (query) + 30816 (gallery)\n    \"\"\"\n    dataset_dir = 'PersonX'\n\n    def __init__(self, root, verbose=True, **kwargs):\n        super(PersonX, self).__init__()\n        self.dataset_dir = osp.join(root, self.dataset_dir)\n        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')\n        self.query_dir = osp.join(self.dataset_dir, 'query')\n        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')\n\n        self._check_before_run()\n\n        train = self._process_dir(self.train_dir, relabel=True)\n        query = self._process_dir(self.query_dir, relabel=False)\n        gallery = self._process_dir(self.gallery_dir, relabel=False)\n\n        if verbose:\n            print(\"=> PersonX loaded\")\n            self.print_dataset_statistics(train, query, gallery)\n\n        self.train = train\n        self.query = query\n        self.gallery = gallery\n\n        self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)\n        self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)\n        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)\n\n    def _check_before_run(self):\n        \"\"\"Check if all files are available before going deeper\"\"\"\n        if not osp.exists(self.dataset_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.dataset_dir))\n        if not osp.exists(self.train_dir):\n            raise RuntimeError(\"'{}' is not 
available\".format(self.train_dir))\n        if not osp.exists(self.query_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.query_dir))\n        if not osp.exists(self.gallery_dir):\n            raise RuntimeError(\"'{}' is not available\".format(self.gallery_dir))\n\n    def _process_dir(self, dir_path, relabel=False):\n        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))\n        pattern = re.compile(r'([-\\d]+)_c([-\\d]+)')\n        cam2label = {3: 1, 4: 2, 8: 3, 10: 4, 11: 5, 12: 6}\n\n        pid_container = set()\n        for img_path in img_paths:\n            pid, _ = map(int, pattern.search(img_path).groups())\n            pid_container.add(pid)\n        pid2label = {pid: label for label, pid in enumerate(pid_container)}\n\n        dataset = []\n        for img_path in img_paths:\n            pid, camid = map(int, pattern.search(img_path).groups())\n            assert (camid in cam2label.keys())\n            camid = cam2label[camid]\n            camid -= 1  # index starts from 0\n            if relabel: pid = pid2label[pid]\n            dataset.append((img_path, pid, camid))\n\n        return dataset\n"
  },
  {
    "path": "hhcl/evaluation_metrics/__init__.py",
    "content": "from __future__ import absolute_import\n\nfrom .classification import accuracy\nfrom .ranking import cmc, mean_ap\n\n__all__ = [\n    'accuracy',\n    'cmc',\n    'mean_ap'\n]\n"
  },
  {
    "path": "hhcl/evaluation_metrics/classification.py",
    "content": "from __future__ import absolute_import\n\nimport torch\nfrom ..utils import to_torch\n\n\ndef accuracy(output, target, topk=(1,)):\n    \"\"\"Compute top-k accuracy for each k in `topk`.\"\"\"\n    with torch.no_grad():\n        output, target = to_torch(output), to_torch(target)\n        maxk = max(topk)\n        batch_size = target.size(0)\n\n        # Indices of the maxk highest scores per sample, transposed to (maxk, batch_size)\n        _, pred = output.topk(maxk, 1, True, True)\n        pred = pred.t()\n        correct = pred.eq(target.view(1, -1).expand_as(pred))\n\n        ret = []\n        for k in topk:\n            # reshape instead of view: the sliced tensor may be non-contiguous\n            correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)\n            ret.append(correct_k.mul_(1. / batch_size))\n        return ret\n"
  },
  {
    "path": "hhcl/evaluation_metrics/ranking.py",
    "content": "from __future__ import absolute_import\nfrom collections import defaultdict\n\nimport numpy as np\nfrom sklearn.metrics import average_precision_score\n\nfrom ..utils import to_numpy\n\n\ndef _unique_sample(ids_dict, num):\n    # np.bool was removed in NumPy 1.24; the builtin bool is equivalent here\n    mask = np.zeros(num, dtype=bool)\n    for _, indices in ids_dict.items():\n        i = np.random.choice(indices)\n        mask[i] = True\n    return mask\n\n\ndef cmc(distmat, query_ids=None, gallery_ids=None,\n        query_cams=None, gallery_cams=None, topk=100,\n        separate_camera_set=False,\n        single_gallery_shot=False,\n        first_match_break=False):\n    distmat = to_numpy(distmat)\n    m, n = distmat.shape\n    # Fill up default values\n    if query_ids is None:\n        query_ids = np.arange(m)\n    if gallery_ids is None:\n        gallery_ids = np.arange(n)\n    if query_cams is None:\n        query_cams = np.zeros(m).astype(np.int32)\n    if gallery_cams is None:\n        gallery_cams = np.ones(n).astype(np.int32)\n    # Ensure numpy array\n    query_ids = np.asarray(query_ids)\n    gallery_ids = np.asarray(gallery_ids)\n    query_cams = np.asarray(query_cams)\n    gallery_cams = np.asarray(gallery_cams)\n    # Sort and find correct matches\n    indices = np.argsort(distmat, axis=1)\n    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])\n    # Compute CMC for each query\n    ret = np.zeros(topk)\n    num_valid_queries = 0\n    for i in range(m):\n        # Filter out the same id and same camera\n        valid = ((gallery_ids[indices[i]] != query_ids[i]) |\n                 (gallery_cams[indices[i]] != query_cams[i]))\n        if separate_camera_set:\n            # Filter out samples from same camera\n            valid &= (gallery_cams[indices[i]] != query_cams[i])\n        if not np.any(matches[i, valid]): continue\n        if single_gallery_shot:\n            repeat = 10\n            gids = gallery_ids[indices[i][valid]]\n            inds = np.where(valid)[0]\n            ids_dict = 
defaultdict(list)\n            for j, x in zip(inds, gids):\n                ids_dict[x].append(j)\n        else:\n            repeat = 1\n        for _ in range(repeat):\n            if single_gallery_shot:\n                # Randomly choose one instance for each id\n                sampled = (valid & _unique_sample(ids_dict, len(valid)))\n                index = np.nonzero(matches[i, sampled])[0]\n            else:\n                index = np.nonzero(matches[i, valid])[0]\n            delta = 1. / (len(index) * repeat)\n            for j, k in enumerate(index):\n                if k - j >= topk: break\n                if first_match_break:\n                    ret[k - j] += 1\n                    break\n                ret[k - j] += delta\n        num_valid_queries += 1\n    if num_valid_queries == 0:\n        raise RuntimeError(\"No valid query\")\n    return ret.cumsum() / num_valid_queries\n\n\ndef mean_ap(distmat, query_ids=None, gallery_ids=None,\n            query_cams=None, gallery_cams=None):\n    distmat = to_numpy(distmat)\n    m, n = distmat.shape\n    # Fill up default values\n    if query_ids is None:\n        query_ids = np.arange(m)\n    if gallery_ids is None:\n        gallery_ids = np.arange(n)\n    if query_cams is None:\n        query_cams = np.zeros(m).astype(np.int32)\n    if gallery_cams is None:\n        gallery_cams = np.ones(n).astype(np.int32)\n    # Ensure numpy array\n    query_ids = np.asarray(query_ids)\n    gallery_ids = np.asarray(gallery_ids)\n    query_cams = np.asarray(query_cams)\n    gallery_cams = np.asarray(gallery_cams)\n    # Sort and find correct matches\n    indices = np.argsort(distmat, axis=1)\n    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])\n    # Compute AP for each query\n    aps = []\n    for i in range(m):\n        # Filter out the same id and same camera\n        valid = ((gallery_ids[indices[i]] != query_ids[i]) |\n                 (gallery_cams[indices[i]] != query_cams[i]))\n        y_true = 
matches[i, valid]\n        y_score = -distmat[i][indices[i]][valid]\n        if not np.any(y_true): continue\n        aps.append(average_precision_score(y_true, y_score))\n    if len(aps) == 0:\n        raise RuntimeError(\"No valid query\")\n    return np.mean(aps)\n"
  },
  {
    "path": "hhcl/evaluators.py",
    "content": "from __future__ import print_function, absolute_import\nimport time\nimport collections\nfrom collections import OrderedDict\nimport numpy as np\nimport torch\nimport random\nimport copy\n\nfrom .evaluation_metrics import cmc, mean_ap\nfrom .utils.meters import AverageMeter\nfrom .utils.rerank import re_ranking\nfrom .utils import to_torch\n\n\ndef extract_cnn_feature(model, inputs):\n    inputs = to_torch(inputs).cuda()\n    outputs = model(inputs)\n    outputs = outputs.data.cpu()\n    return outputs\n\n\ndef extract_features(model, data_loader, print_freq=50):\n    model.eval()\n    batch_time = AverageMeter()\n    data_time = AverageMeter()\n\n    features = OrderedDict()\n    labels = OrderedDict()\n\n    end = time.time()\n    with torch.no_grad():\n        for i, (imgs, fnames, pids, _, _) in enumerate(data_loader):\n            data_time.update(time.time() - end)\n\n            outputs = extract_cnn_feature(model, imgs)\n            for fname, output, pid in zip(fnames, outputs, pids):\n                features[fname] = output\n                labels[fname] = pid\n\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if (i + 1) % print_freq == 0:\n                print('Extract Features: [{}/{}]\\t'\n                      'Time {:.3f} ({:.3f})\\t'\n                      'Data {:.3f} ({:.3f})\\t'\n                      .format(i + 1, len(data_loader),\n                              batch_time.val, batch_time.avg,\n                              data_time.val, data_time.avg))\n\n    return features, labels\n\n\ndef pairwise_distance(features, query=None, gallery=None):\n    if query is None and gallery is None:\n        n = len(features)\n        x = torch.cat(list(features.values()))\n        x = x.view(n, -1)\n        dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2\n        dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t())\n        return dist_m\n\n    x = 
torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)\n    y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)\n    m, n = x.size(0), y.size(0)\n    x = x.view(m, -1)\n    y = y.view(n, -1)\n    dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \\\n           torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()\n    # Squared Euclidean distance: ||x||^2 + ||y||^2 - 2 * x @ y.T\n    # (keyword form; the positional addmm_(beta, alpha, mat1, mat2) signature is deprecated)\n    dist_m.addmm_(x, y.t(), beta=1, alpha=-2)\n    return dist_m, x.numpy(), y.numpy()\n\n\ndef evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None,\n                 query_ids=None, gallery_ids=None,\n                 query_cams=None, gallery_cams=None,\n                 cmc_topk=(1, 5, 10), cmc_flag=False):\n    if query is not None and gallery is not None:\n        query_ids = [pid for _, pid, _ in query]\n        gallery_ids = [pid for _, pid, _ in gallery]\n        query_cams = [cam for _, _, cam in query]\n        gallery_cams = [cam for _, _, cam in gallery]\n    else:\n        assert (query_ids is not None and gallery_ids is not None\n                and query_cams is not None and gallery_cams is not None)\n\n    # Compute mean AP\n    mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)\n    print('Mean AP: {:4.1%}'.format(mAP))\n\n\n    cmc_configs = {\n        'market1501': dict(separate_camera_set=False,\n                           single_gallery_shot=False,\n                           first_match_break=True),}\n    cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,\n                            query_cams, gallery_cams, **params)\n                  for name, params in cmc_configs.items()}\n\n    print('CMC Scores:')\n    for k in cmc_topk:\n        print('  top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k-1]))\n    if (not cmc_flag):\n        return mAP\n    return cmc_scores['market1501'], mAP\n\n\nclass Evaluator(object):\n    def __init__(self, model):\n        super(Evaluator, self).__init__()\n        self.model = model\n\n    def evaluate(self, 
data_loader, query, gallery, cmc_flag=False, rerank=False):\n        features, _ = extract_features(self.model, data_loader)\n        distmat, query_features, gallery_features = pairwise_distance(features, query, gallery)\n        results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)\n\n        if (not rerank):\n            return results\n\n        print('Applying person re-ranking ...')\n        distmat_qq, _, _ = pairwise_distance(features, query, query)\n        distmat_gg, _, _ = pairwise_distance(features, gallery, gallery)\n        distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy())\n        return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag)\n"
  },
  {
    "path": "hhcl/models/__init__.py",
    "content": "from __future__ import absolute_import\n\nfrom .resnet import *\nfrom .resnet_ibn import *\n\n__factory = {\n    'resnet18': resnet18,\n    'resnet34': resnet34,\n    'resnet50': resnet50,\n    'resnet101': resnet101,\n    'resnet152': resnet152,\n    'resnet_ibn50a': resnet_ibn50a,\n    'resnet_ibn101a': resnet_ibn101a,\n}\n\n\ndef names():\n    return sorted(__factory.keys())\n\n\ndef create(name, *args, **kwargs):\n    \"\"\"\n    Create a model instance.\n\n    Parameters\n    ----------\n    name : str\n        Model name. Can be one of 'resnet18', 'resnet34', 'resnet50',\n        'resnet101', 'resnet152', 'resnet_ibn50a', and 'resnet_ibn101a'.\n    pretrained : bool, optional\n        Only applied for 'resnet*' models. If True, will use ImageNet pretrained\n        model. Default: True\n    cut_at_pooling : bool, optional\n        If True, will cut the model before the last global pooling layer and\n        ignore the remaining kwargs. Default: False\n    num_features : int, optional\n        If positive, will append a Linear layer after the global pooling layer,\n        with this number of output units, followed by a BatchNorm layer.\n        Otherwise these layers will not be appended. Default: 0 for 'resnet*'\n    norm : bool, optional\n        If True, will normalize the feature to be unit L2-norm for each sample.\n        Otherwise will append a ReLU layer after the above Linear layer if\n        num_features > 0. Default: False\n    dropout : float, optional\n        If positive, will append a Dropout layer with this dropout rate.\n        Default: 0\n    num_classes : int, optional\n        If positive, will append a Linear layer at the end as the classifier\n        with this number of output units. Default: 0\n    \"\"\"\n    if name not in __factory:\n        raise KeyError(\"Unknown model: {}\".format(name))\n    return __factory[name](*args, **kwargs)\n"
  },
  {
    "path": "hhcl/models/cm.py",
    "content": "import collections\nimport numpy as np\nfrom abc import ABC\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, autograd\nfrom .losses import CrossEntropyLabelSmooth, FocalTopLoss\n\n\nclass CM(autograd.Function):\n\n    @staticmethod\n    def forward(ctx, inputs, targets, features, momentum):\n        ctx.features = features\n        ctx.momentum = momentum\n        ctx.save_for_backward(inputs, targets)\n        outputs = inputs.mm(ctx.features.t())\n\n        return outputs\n\n    @staticmethod\n    def backward(ctx, grad_outputs):\n        inputs, targets = ctx.saved_tensors\n        grad_inputs = None\n        if ctx.needs_input_grad[0]:\n            grad_inputs = grad_outputs.mm(ctx.features)\n\n        # momentum update\n        for x, y in zip(inputs, targets):\n            ctx.features[y] = ctx.momentum * ctx.features[y] + (1. - ctx.momentum) * x\n            ctx.features[y] /= ctx.features[y].norm()\n\n        return grad_inputs, None, None, None\n\n\ndef cm(inputs, indexes, features, momentum=0.5):\n    return CM.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))\n\n\nclass CM_Hard(autograd.Function):\n\n    @staticmethod\n    def forward(ctx, inputs, targets, features, momentum):\n        ctx.features = features\n        ctx.momentum = momentum\n        ctx.save_for_backward(inputs, targets)\n        outputs = inputs.mm(ctx.features.t())\n\n        return outputs\n\n    @staticmethod\n    def backward(ctx, grad_outputs):\n        inputs, targets = ctx.saved_tensors\n        grad_inputs = None\n        if ctx.needs_input_grad[0]:\n            grad_inputs = grad_outputs.mm(ctx.features)\n\n        batch_centers = collections.defaultdict(list)\n        for instance_feature, index in zip(inputs, targets.tolist()):\n            batch_centers[index].append(instance_feature)\n\n        for index, features in batch_centers.items():\n            distances = []\n            for feature in features:\n     
           distance = feature.unsqueeze(0).mm(ctx.features[index].unsqueeze(0).t())[0][0]\n                distances.append(distance.cpu().numpy())\n\n            median = np.argmin(np.array(distances))\n            ctx.features[index] = ctx.features[index] * ctx.momentum + (1 - ctx.momentum) * features[median]\n            ctx.features[index] /= ctx.features[index].norm()\n\n        return grad_inputs, None, None, None\n\n\ndef cm_hard(inputs, indexes, features, momentum=0.5):\n    return CM_Hard.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))\n\n\nclass CM_Hybrid(autograd.Function):\n\n    @staticmethod\n    def forward(ctx, inputs, targets, features, momentum):\n        ctx.features = features\n        ctx.momentum = momentum\n        ctx.save_for_backward(inputs, targets)\n        outputs = inputs.mm(ctx.features.t())\n\n        return outputs\n\n    @staticmethod\n    def backward(ctx, grad_outputs):\n        inputs, targets = ctx.saved_tensors\n        nums = len(ctx.features)//2\n        grad_inputs = None\n        if ctx.needs_input_grad[0]:\n            grad_inputs = grad_outputs.mm(ctx.features)\n\n        batch_centers = collections.defaultdict(list)\n        for instance_feature, index in zip(inputs, targets.tolist()):\n            batch_centers[index].append(instance_feature)\n\n        for index, features in batch_centers.items():\n            distances = []\n            for feature in features:\n                distance = feature.unsqueeze(0).mm(ctx.features[index].unsqueeze(0).t())[0][0]\n                distances.append(distance.cpu().numpy())\n\n            median = np.argmin(np.array(distances))\n            ctx.features[index] = ctx.features[index] * ctx.momentum + (1 - ctx.momentum) * features[median]\n            ctx.features[index] /= ctx.features[index].norm()\n\n            mean = torch.stack(features, dim=0).mean(0)\n            ctx.features[index+nums] = ctx.features[index+nums] * ctx.momentum + (1 - 
ctx.momentum) * mean\n            ctx.features[index+nums] /= ctx.features[index+nums].norm()\n\n        return grad_inputs, None, None, None\n\n\ndef cm_hybrid(inputs, indexes, features, momentum=0.5):\n    return CM_Hybrid.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device))\n\n\nclass CM_Hybrid_v2(autograd.Function):\n\n    @staticmethod\n    def forward(ctx, inputs, targets, features, momentum, num_instances):\n        ctx.features = features\n        ctx.momentum = momentum\n        ctx.num_instances = num_instances\n        ctx.save_for_backward(inputs, targets)\n        outputs = inputs.mm(ctx.features.t())\n\n        return outputs\n\n    @staticmethod\n    def backward(ctx, grad_outputs):\n        inputs, targets = ctx.saved_tensors\n        nums = len(ctx.features)//(ctx.num_instances + 1)\n        grad_inputs = None\n        if ctx.needs_input_grad[0]:\n            grad_inputs = grad_outputs.mm(ctx.features)\n\n        batch_centers = collections.defaultdict(list)\n        updated = set()\n        for k, (instance_feature, index) in enumerate(zip(inputs, targets.tolist())):\n            batch_centers[index].append(instance_feature)\n            if index not in updated:\n                indexes = [index + nums*i for i in range(1, (targets==index).sum()+1)]\n                ctx.features[indexes] = inputs[targets==index]\n                # ctx.features[indexes] = ctx.features[indexes] * ctx.momentum + (1 - ctx.momentum) * inputs[targets==index]\n                # ctx.features[indexes] /= ctx.features[indexes].norm(dim=1, keepdim=True)\n                updated.add(index)\n\n        for index, features in batch_centers.items():\n            mean = torch.stack(features, dim=0).mean(0)\n            ctx.features[index] = ctx.features[index] * ctx.momentum + (1 - ctx.momentum) * mean\n            ctx.features[index] /= ctx.features[index].norm()\n        \n        return grad_inputs, None, None, None, None\n\n\ndef cm_hybrid_v2(inputs, 
indexes, features, momentum=0.5, num_instances=16, *args):\n    return CM_Hybrid_v2.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device), num_instances)\n\n\nclass ClusterMemory(nn.Module, ABC):\n    \n    __CMfactory = {\n        'CM': cm,\n        'CMhard':cm_hard,\n    }\n\n    def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2, mode='CM', hard_weight=0.5, smooth=0., num_instances=1):\n        super(ClusterMemory, self).__init__()\n        self.num_features = num_features\n        self.num_samples = num_samples\n\n        self.momentum = momentum\n        self.temp = temp\n        self.cm_type = mode\n\n        if smooth>0:\n            self.cross_entropy = CrossEntropyLabelSmooth(self.num_samples, 0.1, True)\n            print('>>> Using CrossEntropy with Label Smoothing.')\n        else: \n            self.cross_entropy = nn.CrossEntropyLoss().cuda() \n\n        if self.cm_type in ['CM', 'CMhard']:\n            self.register_buffer('features', torch.zeros(num_samples, num_features))\n        elif self.cm_type=='CMhybrid':\n            self.hard_weight = hard_weight\n            print('hard_weight: {}'.format(self.hard_weight))\n            self.register_buffer('features', torch.zeros(2*num_samples, num_features))\n        elif self.cm_type=='CMhybrid_v2':\n            self.hard_weight = hard_weight\n            self.num_instances = num_instances\n            self.register_buffer('features', torch.zeros((self.num_instances+1)*num_samples, num_features))\n        else:\n            raise TypeError('Cluster Memory {} is invalid!'.format(self.cm_type))\n\n    def forward(self, inputs, targets):\n\n        if self.cm_type in ['CM', 'CMhard']:\n            outputs = ClusterMemory.__CMfactory[self.cm_type](inputs, targets, self.features, self.momentum)\n            outputs /= self.temp\n            loss = self.cross_entropy(outputs, targets)\n            return loss\n\n        elif self.cm_type=='CMhybrid':\n            
outputs = cm_hybrid(inputs, targets, self.features, self.momentum)\n            outputs /= self.temp\n            output_hard, output_mean = torch.chunk(outputs, 2, dim=1)\n            loss = self.hard_weight * (self.cross_entropy(output_hard, targets) + (1 - self.hard_weight) * self.cross_entropy(output_mean, targets))\n            return loss\n\n        elif self.cm_type=='CMhybrid_v2':\n            outputs = cm_hybrid_v2(inputs, targets, self.features, self.momentum, self.num_instances)\n            out_list = torch.chunk(outputs, self.num_instances+1, dim=1)\n            out = torch.stack(out_list[1:], dim=0)\n            neg = torch.max(out, dim=0)[0]\n            pos = torch.min(out, dim=0)[0]\n            mask = torch.zeros_like(out_list[0]).scatter_(1, targets.unsqueeze(1), 1)\n            logits = mask * pos + (1-mask) * neg\n            loss = self.hard_weight * self.cross_entropy(out_list[0]/self.temp, targets) \\\n                + (1 - self.hard_weight) * self.cross_entropy(logits/self.temp, targets)\n            return loss\n"
  },
  {
    "path": "hhcl/models/dsbn.py",
    "content": "import torch\nimport torch.nn as nn\n\n# Domain-specific BatchNorm\n\nclass DSBN2d(nn.Module):\n    def __init__(self, planes):\n        super(DSBN2d, self).__init__()\n        self.num_features = planes\n        self.BN_S = nn.BatchNorm2d(planes)\n        self.BN_T = nn.BatchNorm2d(planes)\n\n    def forward(self, x):\n        if (not self.training):\n            return self.BN_T(x)\n\n        bs = x.size(0)\n        assert (bs%2==0)\n        split = torch.split(x, int(bs/2), 0)\n        out1 = self.BN_S(split[0].contiguous())\n        out2 = self.BN_T(split[1].contiguous())\n        out = torch.cat((out1, out2), 0)\n        return out\n\nclass DSBN1d(nn.Module):\n    def __init__(self, planes):\n        super(DSBN1d, self).__init__()\n        self.num_features = planes\n        self.BN_S = nn.BatchNorm1d(planes)\n        self.BN_T = nn.BatchNorm1d(planes)\n\n    def forward(self, x):\n        if (not self.training):\n            return self.BN_T(x)\n\n        bs = x.size(0)\n        assert (bs%2==0)\n        split = torch.split(x, int(bs/2), 0)\n        out1 = self.BN_S(split[0].contiguous())\n        out2 = self.BN_T(split[1].contiguous())\n        out = torch.cat((out1, out2), 0)\n        return out\n\ndef convert_dsbn(model):\n    for _, (child_name, child) in enumerate(model.named_children()):\n        assert(not next(model.parameters()).is_cuda)\n        if isinstance(child, nn.BatchNorm2d):\n            m = DSBN2d(child.num_features)\n            m.BN_S.load_state_dict(child.state_dict())\n            m.BN_T.load_state_dict(child.state_dict())\n            setattr(model, child_name, m)\n        elif isinstance(child, nn.BatchNorm1d):\n            m = DSBN1d(child.num_features)\n            m.BN_S.load_state_dict(child.state_dict())\n            m.BN_T.load_state_dict(child.state_dict())\n            setattr(model, child_name, m)\n        else:\n            convert_dsbn(child)\n\ndef convert_bn(model, use_target=True):\n    for _, 
(child_name, child) in enumerate(model.named_children()):\n        assert(not next(model.parameters()).is_cuda)\n        if isinstance(child, DSBN2d):\n            m = nn.BatchNorm2d(child.num_features)\n            if use_target:\n                m.load_state_dict(child.BN_T.state_dict())\n            else:\n                m.load_state_dict(child.BN_S.state_dict())\n            setattr(model, child_name, m)\n        elif isinstance(child, DSBN1d):\n            m = nn.BatchNorm1d(child.num_features)\n            if use_target:\n                m.load_state_dict(child.BN_T.state_dict())\n            else:\n                m.load_state_dict(child.BN_S.state_dict())\n            setattr(model, child_name, m)\n        else:\n            convert_bn(child, use_target=use_target)\n"
  },
  {
    "path": "hhcl/models/kmeans.py",
    "content": "# Written by Yixiao Ge\n\nimport warnings\n\nimport faiss\nimport torch\n\nfrom ..utils import to_numpy, to_torch\n\n__all__ = [\"label_generator_kmeans\"]\n\n\n@torch.no_grad()\ndef label_generator_kmeans(features, num_classes=500, cuda=True):\n\n    assert num_classes, \"num_classes for kmeans is null\"\n\n    # k-means cluster by faiss\n    cluster = faiss.Kmeans(\n        features.size(-1), num_classes, niter=300, verbose=True, gpu=cuda\n    )\n\n    cluster.train(to_numpy(features))\n\n    _, labels = cluster.index.search(to_numpy(features), 1)\n    labels = labels.reshape(-1)\n\n    centers = to_torch(cluster.centroids).float()\n    # labels = to_torch(labels).long()\n\n    # k-means does not have outlier points\n    assert not (-1 in labels)\n\n    return labels, centers, num_classes, None\n"
  },
  {
    "path": "hhcl/models/losses.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Parameter\n\n\nclass CrossEntropyLabelSmooth(nn.Module):\n\t\"\"\"Cross entropy loss with label smoothing regularizer.\n\tReference:\n\tSzegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.\n\tEquation: y = (1 - epsilon) * y + epsilon / K.\n\tArgs:\n\t\tnum_classes (int): number of classes.\n\t\tepsilon (float): smoothing weight.\n\t\"\"\"\n\n\tdef __init__(self, num_classes=0, epsilon=0.1, topk_smoothing=False):\n\t\tsuper(CrossEntropyLabelSmooth, self).__init__()\n\t\tself.num_classes = num_classes\n\t\tself.epsilon = epsilon\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1).cuda()\n\t\tself.k = 1 if not topk_smoothing else self.num_classes//50\n\n\tdef forward(self, inputs, targets):\n\t\t\"\"\"\n\t\tArgs:\n\t\t\tinputs: prediction matrix (before softmax) with shape (batch_size, num_classes)\n\t\t\ttargets: ground truth labels with shape (batch_size)\n\t\t\"\"\"\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\tif self.k >1:\n\t\t\ttopk = torch.argsort(-log_probs)[:,:self.k]\n\t\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1 - self.epsilon)\n\t\t\ttargets += torch.zeros_like(log_probs).scatter_(1, topk, self.epsilon / self.k)\n\t\telse:\n\t\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)\n\t\t\ttargets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes\n\t\tloss = (- targets * log_probs).mean(0).sum()\n\t\treturn loss\n\n\nclass SoftEntropy(nn.Module):\n\tdef __init__(self, input_prob=False):\n\t\tsuper(SoftEntropy, self).__init__()\n\t\tself.input_prob = input_prob\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1).cuda()\n\n\tdef forward(self, inputs, targets):\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\tif self.input_prob:\n\t\t\tloss = (- targets.detach() * log_probs).mean(0).sum()\n\t\telse:\n\t\t\tloss = (- F.softmax(targets, dim=1).detach() * 
log_probs).mean(0).sum()\n\t\treturn loss\n\n\nclass SoftEntropySmooth(nn.Module):\n\tdef __init__(self, epsilon=0.1):\n\t\tsuper(SoftEntropySmooth, self).__init__()\n\t\tself.epsilon = epsilon\n\t\tself.logsoftmax = nn.LogSoftmax(dim=1).cuda()\n\n\tdef forward(self, inputs, soft_targets, targets):\n\t\tlog_probs = self.logsoftmax(inputs)\n\t\ttargets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)\n\t\tsoft_targets = F.softmax(soft_targets, dim=1)\n\t\tsmooth_targets = (1 - self.epsilon) * targets + self.epsilon * soft_targets\n\t\tloss = (- smooth_targets.detach() * log_probs).mean(0).sum()\n\t\treturn loss\n\t\n\nclass Softmax(nn.Module):\n\n\tdef __init__(self, feat_dim, num_class, temp=0.05):\n\t\tsuper(Softmax, self).__init__()\n\t\tself.weight = Parameter(torch.Tensor(feat_dim, num_class))\n\t\tself.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)\n\t\tself.temp = temp\n\n\tdef forward(self, feats, labels):\n\t\tkernel_norm = F.normalize(self.weight, dim=0)\n\t\tfeats = F.normalize(feats)\n\t\toutputs = feats.mm(kernel_norm)\n\t\toutputs /= self.temp\n\t\tloss = F.cross_entropy(outputs, labels)\n\t\treturn loss\n\n\nclass CircleLoss(nn.Module):\n    \"\"\"Implementation for \"Circle Loss: A Unified Perspective of Pair Similarity Optimization\"\n    Note: this is the classification based implementation of circle loss.\n    \"\"\"\n    def __init__(self, feat_dim, num_class, margin=0.25, gamma=256):\n        super(CircleLoss, self).__init__()\n        self.weight = Parameter(torch.Tensor(feat_dim, num_class))\n        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)\n        self.margin = margin\n        self.gamma = gamma\n\n        self.O_p = 1 + margin\n        self.O_n = -margin\n        self.delta_p = 1-margin\n        self.delta_n = margin\n\n    def forward(self, feats, labels):\n        kernel_norm = F.normalize(self.weight, dim=0)\n        feats = F.normalize(feats)\n        cos_theta = torch.mm(feats, 
kernel_norm) \n        cos_theta = cos_theta.clamp(-1, 1)\n        index_pos = torch.zeros_like(cos_theta)        \n        index_pos.scatter_(1, labels.data.view(-1, 1), 1)\n        index_pos = index_pos.bool()\n        index_neg = torch.ones_like(cos_theta)        \n        index_neg.scatter_(1, labels.data.view(-1, 1), 0)\n        index_neg = index_neg.bool()\n\n        alpha_p = torch.clamp_min(self.O_p - cos_theta.detach(), min=0.)\n        alpha_n = torch.clamp_min(cos_theta.detach() - self.O_n, min=0.)\n\n        logit_p = alpha_p * (cos_theta - self.delta_p)\n        logit_n = alpha_n * (cos_theta - self.delta_n)\n\n        output = cos_theta * 1.0\n        output[index_pos] = logit_p[index_pos]\n        output[index_neg] = logit_n[index_neg]\n        output *= self.gamma\n\n        return F.cross_entropy(output, labels)\n\n\nclass CosFace(nn.Module):\n    r\"\"\"Implementation of CosFace (https://arxiv.org/pdf/1801.09414.pdf):\n    Args:\n        feat_dim: size of each input sample\n        num_class: size of each output sample\n        s: norm of input feature\n        m: margin\n        cos(theta)-m\n    \"\"\"\n    def __init__(self, feat_dim, num_class, s = 64.0, m = 0.35):\n        super(CosFace, self).__init__()\n        self.in_features = feat_dim\n        self.out_features = num_class\n        self.s = s\n        self.m = m\n\n        self.weight = Parameter(torch.FloatTensor(feat_dim, num_class))\n        nn.init.xavier_uniform_(self.weight)\n\n    def forward(self, input, label):\n        # --------------------------- cos(theta) & phi(theta) ---------------------------\n        # cosine = F.linear(F.normalize(input), F.normalize(self.weight, dim=1))\n        cosine = torch.mm(F.normalize(input), F.normalize(self.weight, dim=0)) \n        phi = cosine - self.m\n        # --------------------------- convert label to one-hot ---------------------------\n        # allocate on the same device as the logits instead of hard-coding 'cuda'\n        one_hot = torch.zeros_like(cosine)\n        # one_hot = 
one_hot.cuda() if cosine.is_cuda else one_hot\n        one_hot.scatter_(1, label.view(-1, 1).long(), 1)\n        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------\n        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4\n        output *= self.s\n\n        return F.cross_entropy(output, label)\n\n    def __repr__(self):\n        return self.__class__.__name__ + '(' \\\n               + 'in_features = ' + str(self.in_features) \\\n               + ', out_features = ' + str(self.out_features) \\\n               + ', s = ' + str(self.s) \\\n               + ', m = ' + str(self.m) + ')'\n\n\nimport math\n\nclass InstanceLoss(nn.Module):\n    def __init__(self, batch_size, temperature, device):\n        super(InstanceLoss, self).__init__()\n        self.batch_size = batch_size\n        self.temperature = temperature\n        self.device = device\n\n        self.mask = self.mask_correlated_samples(batch_size)\n        self.criterion = nn.CrossEntropyLoss(reduction=\"sum\")\n\n    def mask_correlated_samples(self, batch_size):\n        N = 2 * batch_size\n        mask = torch.ones((N, N))\n        mask = mask.fill_diagonal_(0)\n        for i in range(batch_size):\n            mask[i, batch_size + i] = 0\n            mask[batch_size + i, i] = 0\n        mask = mask.bool()\n        return mask\n\n    def forward(self, z_i, z_j):\n        N = 2 * self.batch_size\n        z = torch.cat((z_i, z_j), dim=0)\n\n        sim = torch.matmul(z, z.T) / self.temperature\n        sim_i_j = torch.diag(sim, self.batch_size)\n        sim_j_i = torch.diag(sim, -self.batch_size)\n\n        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)\n        negative_samples = sim[self.mask].reshape(N, -1)\n\n        labels = torch.zeros(N).to(positive_samples.device).long()\n        logits = torch.cat((positive_samples, negative_samples), dim=1)\n        loss = self.criterion(logits, 
labels)\n        loss /= N\n\n        return loss\n\n\nclass ClusterLoss(nn.Module):\n    def __init__(self, class_num, temperature, device):\n        super(ClusterLoss, self).__init__()\n        self.class_num = class_num\n        self.temperature = temperature\n        self.device = device\n\n        self.mask = self.mask_correlated_clusters(class_num)\n        self.criterion = nn.CrossEntropyLoss(reduction=\"sum\")\n        self.similarity_f = nn.CosineSimilarity(dim=2)\n\n    def mask_correlated_clusters(self, class_num):\n        N = 2 * class_num\n        mask = torch.ones((N, N))\n        mask = mask.fill_diagonal_(0)\n        for i in range(class_num):\n            mask[i, class_num + i] = 0\n            mask[class_num + i, i] = 0\n        mask = mask.bool()\n        return mask\n\n    def forward(self, c_i, c_j):\n        p_i = c_i.sum(0).view(-1)\n        p_i /= p_i.sum()\n        ne_i = math.log(p_i.size(0)) + (p_i * torch.log(p_i)).sum()\n        p_j = c_j.sum(0).view(-1)\n        p_j /= p_j.sum()\n        ne_j = math.log(p_j.size(0)) + (p_j * torch.log(p_j)).sum()\n        ne_loss = ne_i + ne_j\n\n        c_i = c_i.t()\n        c_j = c_j.t()\n        N = 2 * self.class_num\n        c = torch.cat((c_i, c_j), dim=0)\n\n        sim = self.similarity_f(c.unsqueeze(1), c.unsqueeze(0)) / self.temperature\n        sim_i_j = torch.diag(sim, self.class_num)\n        sim_j_i = torch.diag(sim, -self.class_num)\n\n        positive_clusters = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)\n        negative_clusters = sim[self.mask].reshape(N, -1)\n\n        labels = torch.zeros(N).to(positive_clusters.device).long()\n        logits = torch.cat((positive_clusters, negative_clusters), dim=1)\n        loss = self.criterion(logits, labels)\n        loss /= N\n\n        return loss + ne_loss\n\n\nclass FocalLoss(nn.Module):\n    def __init__(self, gamma=2, alpha=0.25):\n        super(FocalLoss, self).__init__()\n        self.alpha = alpha\n        self.gamma = 
gamma\n        print('Initializing FocalLoss for training: alpha={}, gamma={}'.format(self.alpha, self.gamma))\n\n    def forward(self, input, target):\n        assert input.dim() == 2\n        assert not target.requires_grad\n        target = target.squeeze(1) if target.dim() == 2 else target\n        assert target.dim() == 1\n\n        logpt = F.log_softmax(input, dim=1)\n        logpt_gt = logpt.gather(1,target.unsqueeze(1))\n        logpt_gt = logpt_gt.view(-1)\n        pt_gt = logpt_gt.exp()\n        assert logpt_gt.size() == pt_gt.size()\n        \n        loss = -self.alpha*(torch.pow((1-pt_gt), self.gamma))*logpt_gt\n        \n        return loss.mean()\n\n\nclass LabelRefineLoss(nn.Module):\n    def __init__(self, lambda1=0.0):\n        super(LabelRefineLoss, self).__init__()\n        self.lambda1 = lambda1\n        print('Initializing LabelRefineLoss for training: lambda1={}'.format(self.lambda1))\n            \n    def forward(self, input, target):\n        assert input.dim() == 2\n        assert not target.requires_grad\n        target = target.squeeze(1) if target.dim() == 2 else target\n        assert target.dim() == 1\n\n        logpt = F.log_softmax(input, dim=1)\n        logpt_gt = logpt.gather(1,target.unsqueeze(1))\n        logpt_gt = logpt_gt.view(-1)\n        logpt_pred,_ = torch.max(logpt,1)\n        logpt_pred = logpt_pred.view(-1)\n        assert logpt_gt.size() == logpt_pred.size()\n        loss = - (1-self.lambda1)*logpt_gt - self.lambda1* logpt_pred\n        \n        return loss.mean()\n\n\nclass FocalTopLoss(nn.Module):\n    def __init__(self, top_percent=0.7):\n        super(FocalTopLoss, self).__init__()\n        self.top_percent = top_percent\n\n    def masked_softmax_multi_focal(self, vec, targets=None, dim=1):\n        exps = torch.exp(vec)\n        one_hot_pos = F.one_hot(targets, num_classes=exps.shape[1])\n\n        one_hot_neg = one_hot_pos.new_ones(size=one_hot_pos.shape)\n        one_hot_neg = one_hot_neg - one_hot_pos\n      
  \n        neg_exps = exps.new_zeros(size=exps.shape)\n        neg_exps[one_hot_neg>0] = exps[one_hot_neg>0]\n        ori_neg_exps = neg_exps\n        neg_exps = neg_exps/neg_exps.sum(dim=1, keepdim=True)\n        \n        new_exps = exps.new_zeros(size=exps.shape)\n        new_exps[one_hot_pos>0] = exps[one_hot_pos>0]\n\n        sorted, indices = torch.sort(neg_exps, dim=1, descending=True)\n        sorted_cum_sum = torch.cumsum(sorted, dim=1)\n        sorted_cum_diff = (sorted_cum_sum - self.top_percent).abs()\n        sorted_cum_min_indices = sorted_cum_diff.argmin(dim=1)\n        \n        # torch.range is deprecated; torch.arange yields the same row indices 0..B-1\n        min_values = sorted[torch.arange(sorted.shape[0], device=sorted.device), sorted_cum_min_indices]\n        min_values = min_values.unsqueeze(dim=-1) * ori_neg_exps.sum(dim=1, keepdim=True)\n        ori_neg_exps[ori_neg_exps<min_values] = 0\n\n        new_exps[one_hot_neg>0] = ori_neg_exps[one_hot_neg>0]\n\n        masked_sums = exps.sum(dim, keepdim=True)\n        return new_exps / masked_sums\n\n    def forward(self, input, target):\n        masked_sim = self.masked_softmax_multi_focal(input, target)\n        return F.nll_loss(torch.log(masked_sim + 1e-6), target)"
  },
  {
    "path": "hhcl/models/pooling.py",
    "content": "# Credit to https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/layers/pooling.py\nfrom abc import ABC\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n__all__ = [\n    \"GeneralizedMeanPoolingPFpn\",\n    \"GeneralizedMeanPoolingList\",\n    \"GeneralizedMeanPoolingP\",\n    \"AdaptiveAvgMaxPool2d\",\n    \"FastGlobalAvgPool2d\",\n    \"avg_pooling\",\n    \"max_pooling\",\n]\n\n\nclass GeneralizedMeanPoolingList(nn.Module, ABC):\n    r\"\"\"Applies a 2D power-average adaptive pooling over an input signal composed of\n    several input planes.\n    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`\n        - At p = infinity, one gets Max Pooling\n        - At p = 1, one gets Average Pooling\n    The output is of size H x W, for any input size.\n    The number of output features is equal to the number of input planes.\n    Args:\n        output_size: the target output size of the image of the form H x W.\n                     Can be a tuple (H, W) or a single H for a square image H x H\n                     H and W can be either a ``int``, or ``None`` which means the size\n                     will be the same as that of the input.\n    \"\"\"\n\n    def __init__(self, output_size=1, eps=1e-6):\n        super(GeneralizedMeanPoolingList, self).__init__()\n        self.output_size = output_size\n        self.eps = eps\n\n    def forward(self, x_list):\n        outs = []\n        for x in x_list:\n            x = x.clamp(min=self.eps)\n            out = torch.nn.functional.adaptive_avg_pool2d(x, self.output_size)\n            outs.append(out)\n        return torch.stack(outs, -1).mean(-1)\n\n    def __repr__(self):\n        return (\n            self.__class__.__name__\n            + \"(\"\n            + \"output_size=\"\n            + str(self.output_size)\n            + \")\"\n        )\n\n\nclass GeneralizedMeanPooling(nn.Module, ABC):\n    r\"\"\"Applies a 2D power-average adaptive pooling over an 
input signal composed of\n    several input planes.\n    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`\n        - At p = infinity, one gets Max Pooling\n        - At p = 1, one gets Average Pooling\n    The output is of size H x W, for any input size.\n    The number of output features is equal to the number of input planes.\n    Args:\n        output_size: the target output size of the image of the form H x W.\n                     Can be a tuple (H, W) or a single H for a square image H x H\n                     H and W can be either a ``int``, or ``None`` which means the size\n                     will be the same as that of the input.\n    \"\"\"\n\n    def __init__(self, norm, output_size=1, eps=1e-6):\n        super(GeneralizedMeanPooling, self).__init__()\n        assert norm > 0\n        self.p = float(norm)\n        self.output_size = output_size\n        self.eps = eps\n\n    def forward(self, x):\n        x = x.clamp(min=self.eps).pow(self.p)\n        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(\n            1.0 / self.p\n        )\n\n    def __repr__(self):\n        return (\n            self.__class__.__name__\n            + \"(\"\n            + str(self.p)\n            + \", \"\n            + \"output_size=\"\n            + str(self.output_size)\n            + \")\"\n        )\n\n\nclass GeneralizedMeanPoolingP(GeneralizedMeanPooling, ABC):\n    \"\"\" Same, but norm is trainable\n    \"\"\"\n\n    def __init__(self, norm=3, output_size=1, eps=1e-6):\n        super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)\n        self.p = nn.Parameter(torch.ones(1) * norm)\n\n\nclass GeneralizedMeanPoolingFpn(nn.Module, ABC):\n    r\"\"\"Applies a 2D power-average adaptive pooling over an input signal composed of\n    several input planes.\n    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`\n        - At p = infinity, one gets Max Pooling\n        - At p = 1, one gets Average 
Pooling\n    The output is of size H x W, for any input size.\n    The number of output features is equal to the number of input planes.\n    Args:\n        output_size: the target output size of the image of the form H x W.\n                     Can be a tuple (H, W) or a single H for a square image H x H\n                     H and W can be either a ``int``, or ``None`` which means the size\n                     will be the same as that of the input.\n    \"\"\"\n\n    def __init__(self, norm, output_size=1, eps=1e-6):\n        super(GeneralizedMeanPoolingFpn, self).__init__()\n        assert norm > 0\n        self.p = float(norm)\n        self.output_size = output_size\n        self.eps = eps\n\n    def forward(self, x_lists):\n        outs = []\n        for x in x_lists:\n            x = x.clamp(min=self.eps).pow(self.p)\n            out = torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(\n                1.0 / self.p\n            )\n            outs.append(out)\n        return torch.cat(outs, 1)\n\n    def __repr__(self):\n        return (\n            self.__class__.__name__\n            + \"(\"\n            + str(self.p)\n            + \", \"\n            + \"output_size=\"\n            + str(self.output_size)\n            + \")\"\n        )\n\n\nclass GeneralizedMeanPoolingPFpn(GeneralizedMeanPoolingFpn, ABC):\n    \"\"\" Same, but norm is trainable\n    \"\"\"\n\n    def __init__(self, norm=3, output_size=1, eps=1e-6):\n        super(GeneralizedMeanPoolingPFpn, self).__init__(norm, output_size, eps)\n        self.p = nn.Parameter(torch.ones(1) * norm)\n\n\nclass AdaptiveAvgMaxPool2d(nn.Module, ABC):\n    def __init__(self):\n        super(AdaptiveAvgMaxPool2d, self).__init__()\n        self.avgpool = FastGlobalAvgPool2d()\n\n    def forward(self, x):\n        # FastGlobalAvgPool2d.forward takes only x and returns (N, C, 1, 1)\n        x_avg = self.avgpool(x)\n        x_max = F.adaptive_max_pool2d(x, 1)\n        x = x_max + x_avg\n        return x\n\n\nclass FastGlobalAvgPool2d(nn.Module, ABC):\n    
def __init__(self, flatten=False):\n        super(FastGlobalAvgPool2d, self).__init__()\n        self.flatten = flatten\n\n    def forward(self, x):\n        if self.flatten:\n            in_size = x.size()\n            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)\n        else:\n            return (\n                x.view(x.size(0), x.size(1), -1)\n                .mean(-1)\n                .view(x.size(0), x.size(1), 1, 1)\n            )\n\n\ndef avg_pooling():\n    return nn.AdaptiveAvgPool2d(1)\n    # return FastGlobalAvgPool2d()\n\n\ndef max_pooling():\n    return nn.AdaptiveMaxPool2d(1)\n\n\nclass Flatten(nn.Module):\n    def forward(self, input):\n        return input.view(input.size(0), -1)\n\n\n__pooling_factory = {\n    \"avg\": avg_pooling,\n    \"max\": max_pooling,\n    \"gem\": GeneralizedMeanPoolingP,\n    \"gemFpn\": GeneralizedMeanPoolingPFpn,\n    \"gemList\": GeneralizedMeanPoolingList,\n    \"avg+max\": AdaptiveAvgMaxPool2d,\n}\n\n\ndef pooling_names():\n    return sorted(__pooling_factory.keys())\n\n\ndef build_pooling_layer(name):\n    \"\"\"\n    Create a pooling layer.\n    Parameters\n    ----------\n    name : str\n        The pooling layer name; one of the names returned by pooling_names().\n    \"\"\"\n    if name not in __pooling_factory:\n        raise KeyError(\"Unknown pooling layer:\", name)\n    return __pooling_factory[name]()"
  },
  {
    "path": "hhcl/models/resnet.py",
    "content": "from __future__ import absolute_import\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn import init\nimport torchvision\nimport torch\nfrom .pooling import build_pooling_layer\n\n\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n           'resnet152']\n\n\n\nclass ResNet(nn.Module):\n    __factory = {\n        18: torchvision.models.resnet18,\n        34: torchvision.models.resnet34,\n        50: torchvision.models.resnet50,\n        101: torchvision.models.resnet101,\n        152: torchvision.models.resnet152,\n    }\n\n    def __init__(self, depth, pretrained=True, cut_at_pooling=False,\n                 num_features=0, norm=False, dropout=0, num_classes=0, pooling_type='avg'):\n        print('pooling_type: {}'.format(pooling_type))\n        super(ResNet, self).__init__()\n        self.pretrained = pretrained\n        self.depth = depth\n        self.cut_at_pooling = cut_at_pooling\n        # Construct base (pretrained) resnet\n        if depth not in ResNet.__factory:\n            raise KeyError(\"Unsupported depth:\", depth)\n        resnet = ResNet.__factory[depth](pretrained=pretrained)\n        if self.depth >= 50:\n            resnet.layer4[0].conv2.stride = (1, 1)\n            resnet.layer4[0].downsample[0].stride = (1, 1)\n        self.base = nn.Sequential(\n            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,\n            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)\n\n        self.gap = build_pooling_layer(pooling_type)\n\n        if not self.cut_at_pooling:\n            self.num_features = num_features\n            self.norm = norm\n            self.dropout = dropout\n            self.has_embedding = num_features > 0\n            self.num_classes = num_classes\n\n            out_planes = resnet.fc.in_features\n\n            # Append new layers\n            if self.has_embedding:\n                self.feat = nn.Linear(out_planes, self.num_features)\n           
     self.feat_bn = nn.BatchNorm1d(self.num_features)\n                init.kaiming_normal_(self.feat.weight, mode='fan_out')\n                init.constant_(self.feat.bias, 0)\n            else:\n                # Change the num_features to CNN output channels\n                self.num_features = out_planes\n                self.feat_bn = nn.BatchNorm1d(self.num_features)\n            self.feat_bn.bias.requires_grad_(False)\n            if self.dropout > 0:\n                self.drop = nn.Dropout(self.dropout)\n            if self.num_classes > 0:\n                self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)\n                init.normal_(self.classifier.weight, std=0.001)\n            # feat_bn only exists when cut_at_pooling is False\n            init.constant_(self.feat_bn.weight, 1)\n            init.constant_(self.feat_bn.bias, 0)\n\n        if not pretrained:\n            self.reset_params()\n\n    def forward(self, x):\n        x = self.base(x)\n\n        x = self.gap(x)\n        x = x.view(x.size(0), -1)\n\n        if self.cut_at_pooling:\n            return x\n\n        if self.has_embedding:\n            bn_x = self.feat_bn(self.feat(x))\n        else:\n            bn_x = self.feat_bn(x)\n\n        if (self.training is False):\n            bn_x = F.normalize(bn_x)\n            return bn_x\n\n        if self.norm:\n            bn_x = F.normalize(bn_x)\n        elif self.has_embedding:\n            bn_x = F.relu(bn_x)\n\n        if self.dropout > 0:\n            bn_x = self.drop(bn_x)\n\n        if self.num_classes > 0:\n            prob = self.classifier(bn_x)\n        else:\n            return bn_x\n\n        return prob, bn_x\n\n    def reset_params(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 
1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm1d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n\n\ndef resnet18(**kwargs):\n    return ResNet(18, **kwargs)\n\n\ndef resnet34(**kwargs):\n    return ResNet(34, **kwargs)\n\n\ndef resnet50(**kwargs):\n    return ResNet(50, **kwargs)\n\n\ndef resnet101(**kwargs):\n    return ResNet(101, **kwargs)\n\n\ndef resnet152(**kwargs):\n    return ResNet(152, **kwargs)\n"
  },
  {
    "path": "hhcl/models/resnet_ibn.py",
    "content": "from __future__ import absolute_import\n\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn import init\nimport torchvision\nimport torch\nfrom .pooling import build_pooling_layer\n\nfrom .resnet_ibn_a import resnet50_ibn_a, resnet101_ibn_a\n\n\n__all__ = ['ResNetIBN', 'resnet_ibn50a', 'resnet_ibn101a']\n\n\nclass ResNetIBN(nn.Module):\n    __factory = {\n        '50a': resnet50_ibn_a,\n        '101a': resnet101_ibn_a\n    }\n\n    def __init__(self, depth, pretrained=True, cut_at_pooling=False,\n                 num_features=0, norm=False, dropout=0, num_classes=0, pooling_type='avg'):\n\n        print('pooling_type: {}'.format(pooling_type))\n        super(ResNetIBN, self).__init__()\n\n        self.depth = depth\n        self.pretrained = pretrained\n        self.cut_at_pooling = cut_at_pooling\n\n        resnet = ResNetIBN.__factory[depth](pretrained=pretrained)\n        resnet.layer4[0].conv2.stride = (1, 1)\n        resnet.layer4[0].downsample[0].stride = (1, 1)\n\n        self.base = nn.Sequential(\n            resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool,\n            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4)\n\n        self.gap = build_pooling_layer(pooling_type)\n\n        if not self.cut_at_pooling:\n            self.num_features = num_features\n            self.norm = norm\n            self.dropout = dropout\n            self.has_embedding = num_features > 0\n            self.num_classes = num_classes\n\n            out_planes = resnet.fc.in_features\n\n            # Append new layers\n            if self.has_embedding:\n                self.feat = nn.Linear(out_planes, self.num_features)\n                self.feat_bn = nn.BatchNorm1d(self.num_features)\n                init.kaiming_normal_(self.feat.weight, mode='fan_out')\n                init.constant_(self.feat.bias, 0)\n            else:\n                # Change the num_features to CNN output channels\n                
self.num_features = out_planes\n                self.feat_bn = nn.BatchNorm1d(self.num_features)\n            self.feat_bn.bias.requires_grad_(False)\n            if self.dropout > 0:\n                self.drop = nn.Dropout(self.dropout)\n            if self.num_classes > 0:\n                self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)\n                init.normal_(self.classifier.weight, std=0.001)\n\n            # feat_bn only exists when cut_at_pooling is False\n            init.constant_(self.feat_bn.weight, 1)\n            init.constant_(self.feat_bn.bias, 0)\n\n        if not pretrained:\n            self.reset_params()\n\n    def forward(self, x):\n        x = self.base(x)\n\n        x = self.gap(x)\n        x = x.view(x.size(0), -1)\n\n        if self.cut_at_pooling:\n            return x\n\n        if self.has_embedding:\n            bn_x = self.feat_bn(self.feat(x))\n        else:\n            bn_x = self.feat_bn(x)\n\n        if self.training is False:\n            bn_x = F.normalize(bn_x)\n            return bn_x\n\n        if self.norm:\n            bn_x = F.normalize(bn_x)\n        elif self.has_embedding:\n            bn_x = F.relu(bn_x)\n\n        if self.dropout > 0:\n            bn_x = self.drop(bn_x)\n\n        if self.num_classes > 0:\n            prob = self.classifier(bn_x)\n        else:\n            return bn_x\n\n        return prob, bn_x\n\n    def reset_params(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, mode='fan_out')\n                if m.bias is not None:\n                    init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.BatchNorm1d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias, 0)\n            elif isinstance(m, nn.Linear):\n                init.normal_(m.weight, std=0.001)\n                if m.bias 
is not None:\n                    init.constant_(m.bias, 0)\n\n\ndef resnet_ibn50a(**kwargs):\n    return ResNetIBN('50a', **kwargs)\n\n\ndef resnet_ibn101a(**kwargs):\n    return ResNetIBN('101a', **kwargs)\n"
  },
  {
    "path": "hhcl/models/resnet_ibn_a.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\nimport torch.utils.model_zoo as model_zoo\n\n\n__all__ = ['ResNet', 'resnet50_ibn_a', 'resnet101_ibn_a']\n\n\nmodel_urls = {\n    'ibn_resnet50a': './examples/pretrained/resnet50_ibn_a.pth.tar',\n    'ibn_resnet101a': './examples/pretrained/resnet101_ibn_a.pth.tar',\n}\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=1, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(BasicBlock, self).__init__()\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass IBN(nn.Module):\n    def __init__(self, planes):\n        super(IBN, self).__init__()\n        half1 = int(planes/2)\n        self.half = half1\n        half2 = planes - half1\n        self.IN = nn.InstanceNorm2d(half1, affine=True)\n        self.BN = nn.BatchNorm2d(half2)\n\n    def forward(self, x):\n        split = torch.split(x, self.half, 1)\n        out1 = self.IN(split[0].contiguous())\n        out2 = self.BN(split[1].contiguous())\n        out = torch.cat((out1, out2), 1)\n        return out\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, 
ibn=False, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        if ibn:\n            self.bn1 = IBN(planes)\n        else:\n            self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,\n                               padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += residual\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(self, block, layers, num_classes=1000):\n        scale = 64\n        self.inplanes = scale\n        super(ResNet, self).__init__()\n        self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = nn.BatchNorm2d(scale)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, scale, layers[0])\n        self.layer2 = self._make_layer(block, scale*2, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, scale*8, layers[3], stride=2)\n        self.avgpool = nn.AvgPool2d(7)\n        self.fc = 
nn.Linear(scale * 8 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2. / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n            elif isinstance(m, nn.InstanceNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion,\n                          kernel_size=1, stride=stride, bias=False),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        ibn = True\n        if planes == 512:\n            ibn = False\n        layers.append(block(self.inplanes, planes, ibn, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for i in range(1, blocks):\n            layers.append(block(self.inplanes, planes, ibn))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n\n        x = self.avgpool(x)\n        x = x.view(x.size(0), -1)\n        x = self.fc(x)\n\n        return x\n\n\ndef resnet50_ibn_a(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-50 model.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)\n    if pretrained:\n        state_dict = torch.load(model_urls['ibn_resnet50a'], 
map_location=torch.device('cpu'))['state_dict']\n        state_dict = remove_module_key(state_dict)\n        model.load_state_dict(state_dict)\n    return model\n\n\ndef resnet101_ibn_a(pretrained=False, **kwargs):\n    \"\"\"Constructs a ResNet-101 model.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)\n    if pretrained:\n        state_dict = torch.load(model_urls['ibn_resnet101a'], map_location=torch.device('cpu'))['state_dict']\n        state_dict = remove_module_key(state_dict)\n        model.load_state_dict(state_dict)\n    return model\n\n\ndef remove_module_key(state_dict):\n    for key in list(state_dict.keys()):\n        if 'module' in key:\n            state_dict[key.replace('module.','')] = state_dict.pop(key)\n    return state_dict\n"
  },
  {
    "path": "hhcl/models/triplet.py",
    "content": "from __future__ import absolute_import\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n\ndef euclidean_dist(x, y):\n\tm, n = x.size(0), y.size(0)\n\txx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)\n\tyy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()\n\tdist = xx + yy\n\t# dist = |x|^2 + |y|^2 - 2 * x @ y.T (keyword form; the positional beta/alpha signature is deprecated)\n\tdist.addmm_(x, y.t(), beta=1, alpha=-2)\n\tdist = dist.clamp(min=1e-12).sqrt()  # for numerical stability\n\treturn dist\n\ndef cosine_dist(x, y):\n\tbs1, bs2 = x.size(0), y.size(0)\n\tfrac_up = torch.matmul(x, y.transpose(0, 1))\n\tfrac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \\\n\t            (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)\n\tcosine = frac_up / frac_down\n\treturn 1-cosine\n\ndef _batch_hard(mat_distance, mat_similarity, indice=False):\n\tsorted_mat_distance, positive_indices = torch.sort(mat_distance + (-9999999.) * (1 - mat_similarity), dim=1, descending=True)\n\thard_p = sorted_mat_distance[:, 0]\n\thard_p_indice = positive_indices[:, 0]\n\tsorted_mat_distance, negative_indices = torch.sort(mat_distance + (9999999.) 
* (mat_similarity), dim=1, descending=False)\n\thard_n = sorted_mat_distance[:, 0]\n\thard_n_indice = negative_indices[:, 0]\n\tif(indice):\n\t\treturn hard_p, hard_n, hard_p_indice, hard_n_indice\n\treturn hard_p, hard_n\n\nclass TripletLoss(nn.Module):\n\t'''\n\tCompute Triplet loss augmented with Batch Hard\n\tDetails can be seen in 'In defense of the Triplet Loss for Person Re-Identification'\n\t'''\n\n\tdef __init__(self, margin, normalize_feature=False):\n\t\tsuper(TripletLoss, self).__init__()\n\t\tself.margin = margin\n\t\tself.normalize_feature = normalize_feature\n\t\tself.margin_loss = nn.MarginRankingLoss(margin=margin).cuda()\n\n\tdef forward(self, emb, label):\n\t\tif self.normalize_feature:\n\t\t\t# equal to cosine similarity\n\t\t\temb = F.normalize(emb)\n\t\tmat_dist = euclidean_dist(emb, emb)\n\t\t# mat_dist = cosine_dist(emb, emb)\n\t\tassert mat_dist.size(0) == mat_dist.size(1)\n\t\tN = mat_dist.size(0)\n\t\tmat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()\n\n\t\tdist_ap, dist_an = _batch_hard(mat_dist, mat_sim)\n\t\tassert dist_an.size(0)==dist_ap.size(0)\n\t\ty = torch.ones_like(dist_ap)\n\t\tloss = self.margin_loss(dist_an, dist_ap, y)\n\t\tprec = (dist_an.data > dist_ap.data).sum() * 1. 
/ y.size(0)\n\t\treturn loss, prec\n\nclass SoftTripletLoss(nn.Module):\n\n\tdef __init__(self, margin=None, normalize_feature=False):\n\t\tsuper(SoftTripletLoss, self).__init__()\n\t\tself.margin = margin\n\t\tself.normalize_feature = normalize_feature\n\n\tdef forward(self, emb1, emb2, label):\n\t\tif self.normalize_feature:\n\t\t\t# equal to cosine similarity\n\t\t\temb1 = F.normalize(emb1)\n\t\t\temb2 = F.normalize(emb2)\n\n\t\tmat_dist = euclidean_dist(emb1, emb1)\n\t\tassert mat_dist.size(0) == mat_dist.size(1)\n\t\tN = mat_dist.size(0)\n\t\tmat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()\n\n\t\tdist_ap, dist_an, ap_idx, an_idx = _batch_hard(mat_dist, mat_sim, indice=True)\n\t\tassert dist_an.size(0)==dist_ap.size(0)\n\t\ttriple_dist = torch.stack((dist_ap, dist_an), dim=1)\n\t\ttriple_dist = F.log_softmax(triple_dist, dim=1)\n\t\tif (self.margin is not None):\n\t\t\tloss = (- self.margin * triple_dist[:,0] - (1 - self.margin) * triple_dist[:,1]).mean()\n\t\t\treturn loss\n\n\t\tmat_dist_ref = euclidean_dist(emb2, emb2)\n\t\tdist_ap_ref = torch.gather(mat_dist_ref, 1, ap_idx.view(N,1).expand(N,N))[:,0]\n\t\tdist_an_ref = torch.gather(mat_dist_ref, 1, an_idx.view(N,1).expand(N,N))[:,0]\n\t\ttriple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1)\n\t\ttriple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach()\n\n\t\tloss = (- triple_dist_ref * triple_dist).mean(0).sum()\n\t\treturn loss"
  },
  {
    "path": "hhcl/trainers.py",
    "content": "from __future__ import print_function, absolute_import\nimport time\nimport torch\nimport torch.nn.functional as F\nfrom .utils.meters import AverageMeter\n\n\nclass Trainer(object):\n    def __init__(self, encoder, memory=None):\n        super(Trainer, self).__init__()\n        self.encoder = encoder\n        self.memory = memory\n\n    def train(self, epoch, data_loader, optimizer, print_freq=10, train_iters=400):\n        self.encoder.train()\n\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n\n        losses = AverageMeter()\n\n        end = time.time()\n        for i in range(train_iters):\n            # load data\n            inputs = data_loader.next()\n            data_time.update(time.time() - end)\n\n            # process inputs\n            inputs, labels, indexes = self._parse_data(inputs)\n\n            loss = 0\n            # forward\n            f_out = self._forward(inputs)\n            loss += self.memory(f_out, labels)\n\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            losses.update(loss.item())\n\n            # print log\n            batch_time.update(time.time() - end)\n            end = time.time()\n\n            if (i + 1) % print_freq == 0:\n                print('Epoch: [{}][{}/{}]\\t'\n                      'Time {:.3f} ({:.3f})\\t'\n                      'Data {:.3f} ({:.3f})\\t'\n                      'Loss {:.3f} ({:.3f})'\n                      .format(epoch, i + 1, len(data_loader),\n                              batch_time.val, batch_time.avg,\n                              data_time.val, data_time.avg,\n                              losses.val, losses.avg))\n\n    def _parse_data(self, inputs):\n        imgs, _, pids, _, indexes = inputs\n        return imgs.cuda(), pids.cuda(), indexes.cuda()\n\n    def _forward(self, inputs):\n        return self.encoder(inputs)"
  },
  {
    "path": "hhcl/utils/__init__.py",
    "content": "from __future__ import absolute_import\n\nimport torch\n\n\ndef to_numpy(tensor):\n    if torch.is_tensor(tensor):\n        return tensor.cpu().numpy()\n    elif type(tensor).__module__ != 'numpy':\n        raise ValueError(\"Cannot convert {} to numpy array\"\n                         .format(type(tensor)))\n    return tensor\n\n\ndef to_torch(ndarray):\n    if type(ndarray).__module__ == 'numpy':\n        return torch.from_numpy(ndarray)\n    elif not torch.is_tensor(ndarray):\n        raise ValueError(\"Cannot convert {} to torch tensor\"\n                         .format(type(ndarray)))\n    return ndarray\n"
  },
  {
    "path": "hhcl/utils/data/__init__.py",
    "content": "from __future__ import absolute_import\n\nfrom .base_dataset import BaseDataset, BaseImageDataset\nfrom .preprocessor import Preprocessor\n\n\nclass IterLoader:\n    def __init__(self, loader, length=None):\n        self.loader = loader\n        self.length = length\n        self.iter = None\n\n    def __len__(self):\n        if self.length is not None:\n            return self.length\n\n        return len(self.loader)\n\n    def new_epoch(self):\n        self.iter = iter(self.loader)\n\n    def next(self):\n        # Lazily create the iterator if new_epoch() was never called.\n        if self.iter is None:\n            self.iter = iter(self.loader)\n        try:\n            return next(self.iter)\n        except StopIteration:\n            # The loader is exhausted: restart it and keep serving batches.\n            self.iter = iter(self.loader)\n            return next(self.iter)\n"
  },
  {
    "path": "hhcl/utils/data/base_dataset.py",
    "content": "# encoding: utf-8\nimport numpy as np\n\n\nclass BaseDataset(object):\n    \"\"\"\n    Base class of reid dataset\n    \"\"\"\n\n    def get_imagedata_info(self, data):\n        pids, cams = [], []\n        for _, pid, camid in data:\n            pids += [pid]\n            cams += [camid]\n        pids = set(pids)\n        cams = set(cams)\n        num_pids = len(pids)\n        num_cams = len(cams)\n        num_imgs = len(data)\n        return num_pids, num_imgs, num_cams\n\n    def print_dataset_statistics(self):\n        raise NotImplementedError\n\n    @property\n    def images_dir(self):\n        return None\n\n\nclass BaseImageDataset(BaseDataset):\n    \"\"\"\n    Base class of image reid dataset\n    \"\"\"\n\n    def print_dataset_statistics(self, train, query, gallery):\n        num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)\n        num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)\n        num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)\n\n        print(\"Dataset statistics:\")\n        print(\"  ----------------------------------------\")\n        print(\"  subset   | # ids | # images | # cameras\")\n        print(\"  ----------------------------------------\")\n        print(\"  train    | {:5d} | {:8d} | {:9d}\".format(num_train_pids, num_train_imgs, num_train_cams))\n        print(\"  query    | {:5d} | {:8d} | {:9d}\".format(num_query_pids, num_query_imgs, num_query_cams))\n        print(\"  gallery  | {:5d} | {:8d} | {:9d}\".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))\n        print(\"  ----------------------------------------\")\n"
  },
  {
    "path": "hhcl/utils/data/preprocessor.py",
    "content": "from __future__ import absolute_import\nimport os\nimport os.path as osp\nfrom torch.utils.data import DataLoader, Dataset\nimport numpy as np\nimport random\nimport math\nfrom PIL import Image\n\n\nclass Preprocessor(Dataset):\n    def __init__(self, dataset, root=None, transform=None, mutual=False):\n        super(Preprocessor, self).__init__()\n        self.dataset = dataset\n        self.root = root\n        self.transform = transform\n        self.mutual = mutual\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, indices):\n        if self.mutual:\n            return self._get_mutual_item(indices)\n        else:\n            return self._get_single_item(indices)\n\n    def _get_single_item(self, index):\n        fname, pid, camid = self.dataset[index]\n        fpath = fname\n        if self.root is not None:\n            fpath = osp.join(self.root, fname)\n\n        img = Image.open(fpath).convert('RGB')\n\n        if self.transform is not None:\n            img = self.transform(img)\n\n        return img, fname, pid, camid, index\n\n    def _get_mutual_item(self, index):\n        fname, pid, camid = self.dataset[index]\n        fpath = fname\n        if self.root is not None:\n            fpath = osp.join(self.root, fname)\n\n        img_1 = Image.open(fpath).convert('RGB')\n        img_2 = img_1.copy()\n\n        if self.transform is not None:\n            img_1 = self.transform(img_1)\n            img_2 = self.transform(img_2)\n\n        return img_1, img_2, pid, camid"
  },
  {
    "path": "hhcl/utils/data/sampler.py",
    "content": "from __future__ import absolute_import\nfrom collections import defaultdict\nimport math\n\nimport numpy as np\nimport copy\nimport random\nimport torch\nfrom torch.utils.data.sampler import (\n    Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,\n    WeightedRandomSampler)\n\n\ndef No_index(a, b):\n    assert isinstance(a, list)\n    return [i for i, j in enumerate(a) if j != b]\n\n\nclass RandomIdentitySampler(Sampler):\n    def __init__(self, data_source, num_instances):\n        self.data_source = data_source\n        self.num_instances = num_instances\n        self.index_dic = defaultdict(list)\n        for index, (_, pid, _) in enumerate(data_source):\n            self.index_dic[pid].append(index)\n        self.pids = list(self.index_dic.keys())\n        self.num_samples = len(self.pids)\n\n    def __len__(self):\n        return self.num_samples * self.num_instances\n\n    def __iter__(self):\n        indices = torch.randperm(self.num_samples).tolist()\n        ret = []\n        for i in indices:\n            pid = self.pids[i]\n            t = self.index_dic[pid]\n            if len(t) >= self.num_instances:\n                t = np.random.choice(t, size=self.num_instances, replace=False)\n            else:\n                t = np.random.choice(t, size=self.num_instances, replace=True)\n            ret.extend(t)\n        return iter(ret)\n\n\nclass RandomMultipleGallerySampler(Sampler):\n    def __init__(self, data_source, num_instances=4):\n        super().__init__(data_source)\n        self.data_source = data_source\n        self.index_pid = defaultdict(int)\n        self.pid_cam = defaultdict(list)\n        self.pid_index = defaultdict(list)\n        self.num_instances = num_instances\n\n        for index, (_, pid, cam) in enumerate(data_source):\n            if pid < 0:\n                continue\n            self.index_pid[index] = pid\n            self.pid_cam[pid].append(cam)\n            
self.pid_index[pid].append(index)\n\n        self.pids = list(self.pid_index.keys())\n        self.num_samples = len(self.pids)\n\n    def __len__(self):\n        return self.num_samples * self.num_instances\n\n    def __iter__(self):\n        indices = torch.randperm(len(self.pids)).tolist()\n        ret = []\n\n        for kid in indices:\n            i = random.choice(self.pid_index[self.pids[kid]])\n\n            _, i_pid, i_cam = self.data_source[i]\n\n            ret.append(i)\n\n            pid_i = self.index_pid[i]\n            cams = self.pid_cam[pid_i]\n            index = self.pid_index[pid_i]\n            select_cams = No_index(cams, i_cam)\n\n            if select_cams:\n\n                if len(select_cams) >= self.num_instances:\n                    cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False)\n                else:\n                    cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True)\n\n                for kk in cam_indexes:\n                    ret.append(index[kk])\n\n            else:\n                select_indexes = No_index(index, i)\n                if not select_indexes:\n                    continue\n                if len(select_indexes) >= self.num_instances:\n                    ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False)\n                else:\n                    ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True)\n\n                for kk in ind_indexes:\n                    ret.append(index[kk])\n\n        return iter(ret)\n"
  },
  {
    "path": "hhcl/utils/data/transforms.py",
    "content": "from __future__ import absolute_import\n\nfrom torchvision.transforms import *\nfrom PIL import Image\nimport random\nimport math\nimport numpy as np\n\nclass RectScale(object):\n    def __init__(self, height, width, interpolation=Image.BILINEAR):\n        self.height = height\n        self.width = width\n        self.interpolation = interpolation\n\n    def __call__(self, img):\n        w, h = img.size\n        if h == self.height and w == self.width:\n            return img\n        return img.resize((self.width, self.height), self.interpolation)\n\n\nclass RandomSizedRectCrop(object):\n    def __init__(self, height, width, interpolation=Image.BILINEAR):\n        self.height = height\n        self.width = width\n        self.interpolation = interpolation\n\n    def __call__(self, img):\n        for attempt in range(10):\n            area = img.size[0] * img.size[1]\n            target_area = random.uniform(0.64, 1.0) * area\n            aspect_ratio = random.uniform(2, 3)\n\n            h = int(round(math.sqrt(target_area * aspect_ratio)))\n            w = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if w <= img.size[0] and h <= img.size[1]:\n                x1 = random.randint(0, img.size[0] - w)\n                y1 = random.randint(0, img.size[1] - h)\n\n                img = img.crop((x1, y1, x1 + w, y1 + h))\n                assert(img.size == (w, h))\n\n                return img.resize((self.width, self.height), self.interpolation)\n\n        # Fallback\n        scale = RectScale(self.height, self.width,\n                          interpolation=self.interpolation)\n        return scale(img)\n\n\nclass RandomErasing(object):\n    \"\"\" Randomly selects a rectangle region in an image and erases its pixels.\n        'Random Erasing Data Augmentation' by Zhong et al.\n        See https://arxiv.org/pdf/1708.04896.pdf\n    Args:\n         probability: The probability that the Random Erasing operation will be performed.\n       
  sl: Minimum proportion of erased area against input image.\n         sh: Maximum proportion of erased area against input image.\n         r1: Minimum aspect ratio of erased area.\n         mean: Erasing value.\n    \"\"\"\n\n    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):\n        self.probability = probability\n        self.mean = mean\n        self.sl = sl\n        self.sh = sh\n        self.r1 = r1\n\n    def __call__(self, img):\n\n        if random.uniform(0, 1) >= self.probability:\n            return img\n\n        for attempt in range(100):\n            area = img.size()[1] * img.size()[2]\n\n            target_area = random.uniform(self.sl, self.sh) * area\n            aspect_ratio = random.uniform(self.r1, 1 / self.r1)\n\n            h = int(round(math.sqrt(target_area * aspect_ratio)))\n            w = int(round(math.sqrt(target_area / aspect_ratio)))\n\n            if w < img.size()[2] and h < img.size()[1]:\n                x1 = random.randint(0, img.size()[1] - h)\n                y1 = random.randint(0, img.size()[2] - w)\n                if img.size()[0] == 3:\n                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]\n                    img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]\n                    img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]\n                else:\n                    img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]\n                return img\n\n        return img\n"
  },
  {
    "path": "hhcl/utils/faiss_rerank.py",
    "content": "#!/usr/bin/env python3\n# -*- coding: utf-8 -*-\n\"\"\"\nCVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.\nurl:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf\nMatlab version: https://github.com/zhunzhong07/person-re-ranking\n\"\"\"\n\nimport os, sys\nimport time\nimport numpy as np\nfrom scipy.spatial.distance import cdist\nimport gc\nimport faiss\n\nimport torch\nimport torch.nn.functional as F\n\nfrom .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \\\n                            index_init_gpu, index_init_cpu\n\n\ndef k_reciprocal_neigh(initial_rank, i, k1):\n    forward_k_neigh_index = initial_rank[i,:k1+1]\n    backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]\n    fi = np.where(backward_k_neigh_index==i)[0]\n    return forward_k_neigh_index[fi]\n\n\ndef compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False):\n    end = time.time()\n    if print_flag:\n        print('Computing jaccard distance...')\n\n    ngpus = faiss.get_num_gpus()\n    N = target_features.size(0)\n    mat_type = np.float16 if use_float16 else np.float32\n\n    if (search_option==0):\n        # GPU + PyTorch CUDA Tensors (1)\n        res = faiss.StandardGpuResources()\n        res.setDefaultNullStreamAllDevices()\n        _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1)\n        initial_rank = initial_rank.cpu().numpy()\n    elif (search_option==1):\n        # GPU + PyTorch CUDA Tensors (2)\n        res = faiss.StandardGpuResources()\n        index = faiss.GpuIndexFlatL2(res, target_features.size(-1))\n        index.add(target_features.cpu().numpy())\n        _, initial_rank = search_index_pytorch(index, target_features, k1)\n        res.syncDefaultStreamCurrentDevice()\n        initial_rank = 
initial_rank.cpu().numpy()\n    elif (search_option==2):\n        # GPU\n        index = index_init_gpu(ngpus, target_features.size(-1))\n        index.add(target_features.cpu().numpy())\n        _, initial_rank = index.search(target_features.cpu().numpy(), k1)\n    else:\n        # CPU\n        index = index_init_cpu(target_features.size(-1))\n        index.add(target_features.cpu().numpy())\n        _, initial_rank = index.search(target_features.cpu().numpy(), k1)\n\n\n    nn_k1 = []\n    nn_k1_half = []\n    for i in range(N):\n        nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))\n        nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1/2))))\n\n    V = np.zeros((N, N), dtype=mat_type)\n    for i in range(N):\n        k_reciprocal_index = nn_k1[i]\n        k_reciprocal_expansion_index = k_reciprocal_index\n        for candidate in k_reciprocal_index:\n            candidate_k_reciprocal_index = nn_k1_half[candidate]\n            if (len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index)) > 2/3*len(candidate_k_reciprocal_index)):\n                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)\n\n        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)  ## element-wise unique\n        dist = 2-2*torch.mm(target_features[i].unsqueeze(0).contiguous(), target_features[k_reciprocal_expansion_index].t())\n        if use_float16:\n            V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type)\n        else:\n            V[i,k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy()\n\n    del nn_k1, nn_k1_half\n\n    if k2 != 1:\n        V_qe = np.zeros_like(V, dtype=mat_type)\n        for i in range(N):\n            V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:], axis=0)\n        V = V_qe\n        del V_qe\n\n    del initial_rank\n\n    invIndex = []\n    for i in range(N):\n        
invIndex.append(np.where(V[:,i] != 0)[0])  #len(invIndex)=all_num\n\n    jaccard_dist = np.zeros((N, N), dtype=mat_type)\n    for i in range(N):\n        temp_min = np.zeros((1, N), dtype=mat_type)\n        # temp_max = np.zeros((1,N), dtype=mat_type)\n        indNonZero = np.where(V[i, :] != 0)[0]\n        indImages = []\n        indImages = [invIndex[ind] for ind in indNonZero]\n        for j in range(len(indNonZero)):\n            temp_min[0, indImages[j]] = temp_min[0, indImages[j]]+np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])\n            # temp_max[0,indImages[j]] = temp_max[0,indImages[j]]+np.maximum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])\n\n        jaccard_dist[i] = 1-temp_min/(2-temp_min)\n        # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6)\n\n    del invIndex, V\n\n    pos_bool = (jaccard_dist < 0)\n    jaccard_dist[pos_bool] = 0.0\n    if print_flag:\n        print(\"Jaccard distance computing time cost: {}\".format(time.time()-end))\n\n    return jaccard_dist\n"
  },
  {
    "path": "hhcl/utils/faiss_utils.py",
    "content": "import os\nimport numpy as np\nimport faiss\nimport torch\n\ndef swig_ptr_from_FloatTensor(x):\n    assert x.is_contiguous()\n    assert x.dtype == torch.float32\n    return faiss.cast_integer_to_float_ptr(\n        x.storage().data_ptr() + x.storage_offset() * 4)\n\ndef swig_ptr_from_LongTensor(x):\n    assert x.is_contiguous()\n    assert x.dtype == torch.int64, 'dtype=%s' % x.dtype\n\n    return faiss.cast_integer_to_long_ptr(\n        x.storage().data_ptr() + x.storage_offset() * 8)\n\ndef search_index_pytorch(index, x, k, D=None, I=None):\n    \"\"\"call the search function of an index with pytorch tensor I/O (CPU\n    and GPU supported)\"\"\"\n    assert x.is_contiguous()\n    n, d = x.size()\n    assert d == index.d\n\n    if D is None:\n        D = torch.empty((n, k), dtype=torch.float32, device=x.device)\n    else:\n        assert D.size() == (n, k)\n\n    if I is None:\n        I = torch.empty((n, k), dtype=torch.int64, device=x.device)\n    else:\n        assert I.size() == (n, k)\n    torch.cuda.synchronize()\n    xptr = swig_ptr_from_FloatTensor(x)\n    Iptr = swig_ptr_from_LongTensor(I)\n    Dptr = swig_ptr_from_FloatTensor(D)\n    index.search_c(n, xptr,\n                   k, Dptr, Iptr)\n    torch.cuda.synchronize()\n    return D, I\n\ndef search_raw_array_pytorch(res, xb, xq, k, D=None, I=None,\n                             metric=faiss.METRIC_L2):\n    assert xb.device == xq.device\n\n    nq, d = xq.size()\n    if xq.is_contiguous():\n        xq_row_major = True\n    elif xq.t().is_contiguous():\n        xq = xq.t()    # I initially wrote xq:t(), Lua is still haunting me :-)\n        xq_row_major = False\n    else:\n        raise TypeError('matrix should be row or column-major')\n\n    xq_ptr = swig_ptr_from_FloatTensor(xq)\n\n    nb, d2 = xb.size()\n    assert d2 == d\n    if xb.is_contiguous():\n        xb_row_major = True\n    elif xb.t().is_contiguous():\n        xb = xb.t()\n        xb_row_major = False\n    else:\n        
raise TypeError('matrix should be row or column-major')\n    xb_ptr = swig_ptr_from_FloatTensor(xb)\n\n    if D is None:\n        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)\n    else:\n        assert D.shape == (nq, k)\n        assert D.device == xb.device\n\n    if I is None:\n        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)\n    else:\n        assert I.shape == (nq, k)\n        assert I.device == xb.device\n\n    D_ptr = swig_ptr_from_FloatTensor(D)\n    I_ptr = swig_ptr_from_LongTensor(I)\n\n    faiss.bruteForceKnn(res, metric,\n                xb_ptr, xb_row_major, nb,\n                xq_ptr, xq_row_major, nq,\n                d, k, D_ptr, I_ptr)\n\n    return D, I\n\ndef index_init_gpu(ngpus, feat_dim):\n    flat_config = []\n    for i in range(ngpus):\n        cfg = faiss.GpuIndexFlatConfig()\n        cfg.useFloat16 = False\n        cfg.device = i\n        flat_config.append(cfg)\n\n    res = [faiss.StandardGpuResources() for i in range(ngpus)]\n    indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)]\n    index = faiss.IndexShards(feat_dim)\n    for sub_index in indexes:\n        index.add_shard(sub_index)\n    index.reset()\n    return index\n\ndef index_init_cpu(feat_dim):\n    return faiss.IndexFlatL2(feat_dim)\n"
  },
  {
    "path": "hhcl/utils/logging.py",
    "content": "from __future__ import absolute_import\nimport os\nimport sys\n\nfrom .osutils import mkdir_if_missing\n\n\nclass Logger(object):\n    \"\"\"Write messages to both the console and an optional log file.\"\"\"\n\n    def __init__(self, fpath=None):\n        self.console = sys.stdout\n        self.file = None\n        if fpath is not None:\n            mkdir_if_missing(os.path.dirname(fpath))\n            self.file = open(fpath, 'w')\n\n    def __del__(self):\n        self.close()\n\n    def __enter__(self):\n        # Return the logger so that `with Logger(path) as log:` binds it.\n        return self\n\n    def __exit__(self, *args):\n        self.close()\n\n    def write(self, msg):\n        self.console.write(msg)\n        if self.file is not None:\n            self.file.write(msg)\n\n    def flush(self):\n        self.console.flush()\n        if self.file is not None:\n            self.file.flush()\n            os.fsync(self.file.fileno())\n\n    def close(self):\n        # Only close the log file; the console is sys.stdout and must stay open.\n        if self.file is not None:\n            self.file.close()\n            self.file = None\n"
  },
  {
    "path": "hhcl/utils/meters.py",
    "content": "from __future__ import absolute_import\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "hhcl/utils/osutils.py",
    "content": "from __future__ import absolute_import\nimport os\nimport errno\n\n\ndef mkdir_if_missing(dir_path):\n    try:\n        os.makedirs(dir_path)\n    except OSError as e:\n        if e.errno != errno.EEXIST:\n            raise\n"
  },
  {
    "path": "hhcl/utils/rerank.py",
    "content": "#!/usr/bin/env python2/python3\n# -*- coding: utf-8 -*-\n\"\"\"\nSource: https://github.com/zhunzhong07/person-re-ranking\nCreated on Mon Jun 26 14:46:56 2017\n@author: luohao\nModified by Houjing Huang, 2017-12-22.\n- This version accepts distance matrix instead of raw features.\n- The difference of `/` division between python 2 and 3 is handled.\n- numpy.float16 is replaced by numpy.float32 for numerical precision.\nCVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.\nurl:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf\nMatlab version: https://github.com/zhunzhong07/person-re-ranking\nAPI\nq_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery]\nq_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query]\ng_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery]\nk1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3)\nReturns:\n  final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery]\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\n__all__ = ['re_ranking']\n\nimport numpy as np\n\n\ndef re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):\n\n    # The following naming, e.g. gallery_num, is different from outer scope.\n    # Don't care about it.\n\n    original_dist = np.concatenate(\n      [np.concatenate([q_q_dist, q_g_dist], axis=1),\n       np.concatenate([q_g_dist.T, g_g_dist], axis=1)],\n      axis=0)\n    original_dist = np.power(original_dist, 2).astype(np.float32)\n    original_dist = np.transpose(1. 
* original_dist/np.max(original_dist,axis = 0))\n    V = np.zeros_like(original_dist).astype(np.float32)\n    initial_rank = np.argsort(original_dist).astype(np.int32)\n\n    query_num = q_g_dist.shape[0]\n    gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]\n    all_num = gallery_num\n\n    for i in range(all_num):\n        # k-reciprocal neighbors\n        forward_k_neigh_index = initial_rank[i,:k1+1]\n        backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]\n        fi = np.where(backward_k_neigh_index==i)[0]\n        k_reciprocal_index = forward_k_neigh_index[fi]\n        k_reciprocal_expansion_index = k_reciprocal_index\n        for j in range(len(k_reciprocal_index)):\n            candidate = k_reciprocal_index[j]\n            candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1]\n            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1]\n            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]\n            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]\n            if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index):\n                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)\n\n        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)\n        weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])\n        V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight)\n    original_dist = original_dist[:query_num,]\n    if k2 != 1:\n        V_qe = np.zeros_like(V,dtype=np.float32)\n        for i in range(all_num):\n            V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)\n        V = V_qe\n        del V_qe\n    del initial_rank\n    invIndex = []\n    for i in range(gallery_num):\n        invIndex.append(np.where(V[:,i] != 0)[0])\n\n    
jaccard_dist = np.zeros_like(original_dist,dtype = np.float32)\n\n\n    for i in range(query_num):\n        temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32)\n        indNonZero = np.where(V[i,:] != 0)[0]\n        indImages = []\n        indImages = [invIndex[ind] for ind in indNonZero]\n        for j in range(len(indNonZero)):\n            temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])\n        jaccard_dist[i] = 1-temp_min/(2.-temp_min)\n\n    final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value\n    del original_dist\n    del V\n    del jaccard_dist\n    final_dist = final_dist[:query_num,query_num:]\n    return final_dist\n"
  },
  {
    "path": "hhcl/utils/serialization.py",
    "content": "from __future__ import print_function, absolute_import\nimport json\nimport os.path as osp\nimport shutil\n\nimport torch\nfrom torch.nn import Parameter\n\nfrom .osutils import mkdir_if_missing\n\n\ndef read_json(fpath):\n    with open(fpath, 'r') as f:\n        obj = json.load(f)\n    return obj\n\n\ndef write_json(obj, fpath):\n    mkdir_if_missing(osp.dirname(fpath))\n    with open(fpath, 'w') as f:\n        json.dump(obj, f, indent=4, separators=(',', ': '))\n\n\ndef save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):\n    mkdir_if_missing(osp.dirname(fpath))\n    torch.save(state, fpath)\n    if is_best:\n        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))\n\n\ndef load_checkpoint(fpath):\n    if osp.isfile(fpath):\n        # checkpoint = torch.load(fpath)\n        checkpoint = torch.load(fpath, map_location=torch.device('cpu'))\n        print(\"=> Loaded checkpoint '{}'\".format(fpath))\n        return checkpoint\n    else:\n        raise ValueError(\"=> No checkpoint found at '{}'\".format(fpath))\n\n\ndef copy_state_dict(state_dict, model, strip=None):\n    tgt_state = model.state_dict()\n    copied_names = set()\n    for name, param in state_dict.items():\n        if strip is not None and name.startswith(strip):\n            name = name[len(strip):]\n        if name not in tgt_state:\n            continue\n        if isinstance(param, Parameter):\n            param = param.data\n        if param.size() != tgt_state[name].size():\n            print('mismatch:', name, param.size(), tgt_state[name].size())\n            continue\n        tgt_state[name].copy_(param)\n        copied_names.add(name)\n\n    missing = set(tgt_state.keys()) - copied_names\n    if len(missing) > 0:\n        print(\"missing keys in state_dict:\", missing)\n\n    return model\n"
  },
  {
    "path": "requirements.txt",
    "content": "numpy\nsklearn\nCython\nh5py\npyzmq\npillow-simd\nsix\nscipy\nmatplotlib\nfaiss-gpu==1.6.3\neasydict"
  },
  {
    "path": "run.sh",
    "content": "### resnet50 ###\n# market1501\nCUDA_VISIBLE_DEVICES=0,1,2,3 python examples/train.py -b 256 -a resnet50 -d market1501 --iters 200 --eps 0.45 --num-instances 16 --pooling-type avg --memorybank CMhybrid --epochs 60 --logs-dir examples/logs/market1501/resnet50_avg_cmhybrid\n\n# dukemtmcreid\nCUDA_VISIBLE_DEVICES=0,1,2,3 python examples/train.py -b 256 -a resnet50 -d dukemtmcreid --iters 200 --eps 0.6 --num-instances 16 --pooling-type avg --memorybank CMhybrid --epochs 60 --logs-dir examples/logs/dukemtmcreid/resnet50_avg_cmhybrid\n\n\n### resnet_ibn50a + gem pooling ###\n# market1501\nCUDA_VISIBLE_DEVICES=0,1,2,3 python examples/train.py -b 256 -a resnet_ibn50a -d market1501 --iters 200 --eps 0.45 --num-instances 16 --pooling-type gem --memorybank CMhybrid --epochs 60 --logs-dir examples/logs/market1501/resnet50_ibn_gem_cmhybrid\n\n# dukemtmcreid\nCUDA_VISIBLE_DEVICES=0,1,2,3 python examples/train.py -b 256 -a resnet_ibn50a -d dukemtmcreid --iters 200 --eps 0.6 --num-instances 16 --pooling-type gem --memorybank CMhybrid --epochs 60 --logs-dir examples/logs/dukemtmcreid/resnet50_ibn_gem_cmhybrid\n\n# test\nCUDA_VISIBLE_DEVICES=0 python examples/test.py -d market1501 --data-dir examples/data/market1501 --pooling-type avg --resume examples/logs/market1501/resnet50_avg_cmhybrid/model_best.pth.tar"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\n\nsetup(name='hhcl',\n      version='1.0.0',\n      install_requires=[\n          'numpy', 'torch', 'torchvision',\n          'six', 'h5py', 'Pillow', 'scipy',\n          'scikit-learn', 'metric-learn', 'faiss_gpu'],\n      packages=find_packages(),\n      keywords=[\n          'Unsupervised Learning',\n          'Contrastive Learning',\n          'Object Re-identification'\n      ])\n"
  }
]