Repository: MalongTech/research-ms-loss
Branch: master
Commit: b68507d4e22d
Files: 60
Total size: 115.8 KB

Directory structure:
gitextract_ibl22pv3/
├── .flake8
├── .gitignore
├── LICENSE
├── README.md
├── ThirdPartyNotices.txt
├── configs/
│   ├── example.yaml
│   ├── example_margin.yaml
│   └── example_resnet50.yaml
├── requirements.txt
├── ret_benchmark/
│   ├── config/
│   │   ├── __init__.py
│   │   ├── defaults.py
│   │   └── model_path.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── collate_batch.py
│   │   ├── datasets/
│   │   │   ├── __init__.py
│   │   │   └── base_dataset.py
│   │   ├── evaluations/
│   │   │   ├── __init__.py
│   │   │   └── ret_metric.py
│   │   ├── samplers/
│   │   │   ├── __init__.py
│   │   │   └── random_identity_sampler.py
│   │   └── transforms/
│   │       ├── __init__.py
│   │       └── build.py
│   ├── engine/
│   │   ├── __init__.py
│   │   └── trainer.py
│   ├── losses/
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── margin_loss.py
│   │   ├── multi_similarity_loss.py
│   │   └── registry.py
│   ├── modeling/
│   │   ├── __init__.py
│   │   ├── backbone/
│   │   │   ├── __init__.py
│   │   │   ├── bninception.py
│   │   │   ├── build.py
│   │   │   └── resnet.py
│   │   ├── build.py
│   │   ├── heads/
│   │   │   ├── __init__.py
│   │   │   ├── build.py
│   │   │   └── linear_norm.py
│   │   ├── registry.py
│   │   └── xbm.py
│   ├── solver/
│   │   ├── __init__.py
│   │   ├── build.py
│   │   └── lr_scheduler.py
│   └── utils/
│       ├── checkpoint.py
│       ├── config_util.py
│       ├── feat_extractor.py
│       ├── freeze_bn.py
│       ├── img_reader.py
│       ├── init_methods.py
│       ├── logger.py
│       ├── metric_logger.py
│       ├── model_serialization.py
│       └── registry.py
├── scripts/
│   ├── prepare_cub.sh
│   ├── run_cub.sh
│   ├── run_cub_margin.sh
│   └── split_cub_for_ms_loss.py
├── setup.py
└── tools/
    └── main.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .flake8
================================================
[flake8]
ignore = F401, F841, E402, E722, E999
max-line-length = 128
max-complexity=18
format=pylint
show_source = True
statistics = True
count = True
exclude = tests,ret_benchmark/modeling/backbone

================================================
FILE: .gitignore
================================================
resource
build
*.pyc
*.zip
*/__pycache__
__pycache__

# Package Files #
*.pkl
*.log
*.jar
*.war
*.nar
*.ear
*.zip
*.tar.gz
*.rar
*.egg-info

# some local files
*/.settings/
*/.DS_Store
.DS_Store
*/.idea/
.idea/
gradlew
gradlew.bat
unused.txt
output/
*.egg-info/

================================================
FILE: LICENSE
================================================
Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC-4.0) Public License

For Multi-Similarity Loss for Deep Metric Learning (MS-Loss)

Copyright (c) 2014-present, Malong Technologies Co., Ltd.
All rights reserved.

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

Section 1 -- Definitions.

a.
Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. c. Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. d. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. e. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. f. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. g. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. h. Licensor means the individual(s) or entity(ies) granting rights under this Public License. i. NonCommercial means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. j. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. k. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. l. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. Section 2 -- Scope. a. License grant. 1. 
Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: a. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and b. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 3. Term. The term of this Public License is specified in Section 6(a). 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a) (4) never produces Adapted Material. 5. Downstream recipients. a. Offer from the Licensor -- Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. b. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). b. Other rights. 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 2. Patent and trademark rights are not licensed under this Public License. 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. Section 3 -- License Conditions. Your exercise of the Licensed Rights is expressly made subject to the following conditions. a. Attribution. 1. If You Share the Licensed Material (including in modified form), You must: a. retain the following if it is supplied by the Licensor with the Licensed Material: i. 
identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); ii. a copyright notice; iii. a notice that refers to this Public License; iv. a notice that refers to the disclaimer of warranties; v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. Section 4 -- Sui Generis Database Rights. Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. Section 5 -- Disclaimer of Warranties and Limitation of Liability. a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. c. 
The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. Section 6 -- Term and Termination. a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 2. upon express reinstatement by the Licensor. For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. Section 7 -- Other Terms and Conditions. a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. Section 8 -- Interpretation. a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 
================================================
FILE: README.md
================================================
[![License: CC BY-NC 4.0](https://licensebuttons.net/l/by-nc/4.0/80x15.png)](https://creativecommons.org/licenses/by-nc/4.0/)

# Multi-Similarity Loss for Deep Metric Learning (MS-Loss)

Code for the CVPR 2019 paper [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)

### Performance compared with SOTA methods on CUB-200-2011

|Rank@K         | 1 | 2 | 4 | 8 | 16 | 32 |
|:---           |:-:|:-:|:-:|:-:|:-: |:-: |
|Clustering64   | 48.2 | 61.4 | 71.8 | 81.9 | - | - |
|ProxyNCA64     | 49.2 | 61.9 | 67.9 | 72.4 | - | - |
|Smart Mining64 | 49.8 | 62.3 | 74.1 | 83.3 | - | - |
|Our MS-Loss64  | **57.4** | **69.8** | **80.0** | **87.8** | 93.2 | 96.4 |
|HTL512         | 57.1 | 68.8 | 78.7 | 86.5 | 92.5 | 95.5 |
|ABIER512       | 57.5 | 68.7 | 78.3 | 86.2 | 91.9 | 95.5 |
|Our MS-Loss512 | **65.7** | **77.0** | **86.3** | **91.2** | **95.0** | **97.3** |

### Prepare the data and the pretrained model

The following script prepares the [CUB](http://www.vision.caltech.edu.s3-us-west-2.amazonaws.com/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) dataset for training by downloading it to the ./resource/datasets/ folder and then building the data lists (train.txt, test.txt):

```bash
./scripts/prepare_cub.sh
```

Download the ImageNet-pretrained [bninception](http://data.lip6.fr/cadene/pretrainedmodels/bn_inception-52deb4733.pth) model and put it in the folder ~/.torch/models/.

### Installation

```bash
pip install -r requirements.txt
python setup.py develop build
```

### Train and Test on CUB200-2011 with MS-Loss

```bash
./scripts/run_cub.sh
```

With the default config, trained models are saved in the ./output/ folder and the best recall@1 should be higher than 66 (65.7 in the paper).

### Contact

For any questions, please feel free to reach us at

```
github@malongtech.com
```

### Citation

If you use this method or this code in your research, please cite as:

```
@inproceedings{wang2019multi,
  title={Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning},
  author={Wang, Xun and Han, Xintong and Huang, Weilin and Dong, Dengke and Scott, Matthew R},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={5022--5030},
  year={2019}
}
```

## License

MS-Loss is CC-BY-NC 4.0 licensed, as found in the [LICENSE](LICENSE) file. It is released for academic research / non-commercial use only. If you wish to use it for commercial purposes, please contact sales@malongtech.com.
================================================
FILE: ThirdPartyNotices.txt
================================================
THIRD PARTY SOFTWARE NOTICES AND INFORMATION
Do Not Translate or Localize

This software incorporates material from the following third parties.

_____

Cadene/pretrained-models.pytorch

BSD 3-Clause License

Copyright (c) 2017, Remi Cadene
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

_____

facebookresearch/maskrcnn-benchmark

MIT License

Copyright (c) 2018 Facebook

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: configs/example.yaml
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

MODEL:
  BACKBONE:
    NAME: bninception

SOLVER:
  MAX_ITERS: 3000
  STEPS: [1200, 2400]
  OPTIMIZER_NAME: Adam
  BASE_LR: 0.00003
  WARMUP_ITERS: 0
  WEIGHT_DECAY: 0.0005

DATA:
  TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
  TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
  TRAIN_BATCHSIZE: 80
  TEST_BATCHSIZE: 256
  NUM_WORKERS: 8
  NUM_INSTANCES: 5

VALIDATION:
  VERBOSE: 200

================================================
FILE: configs/example_margin.yaml
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

MODEL:
  BACKBONE:
    NAME: bninception

LOSSES:
  NAME: margin_loss
  MARGIN_LOSS:
    N_CLASSES: 100
    BETA_CONSTANT: False  # if False (i.e. class-specific beta), train.txt should have labels 0 ... N_CLASSES-1

SOLVER:
  MAX_ITERS: 3000
  STEPS: [1200, 2400]
  OPTIMIZER_NAME: Adam
  BASE_LR: 0.00003
  WARMUP_ITERS: 0
  WEIGHT_DECAY: 0.0005

DATA:
  TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
  TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
  TRAIN_BATCHSIZE: 120
  TEST_BATCHSIZE: 256
  NUM_WORKERS: 8
  NUM_INSTANCES: 5

VALIDATION:
  VERBOSE: 200

SAVE_DIR: output_margin
================================================
FILE: configs/example_resnet50.yaml
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

MODEL:
  BACKBONE:
    NAME: resnet50

INPUT:
  MODE: 'RGB'
  PIXEL_MEAN: [0.485, 0.456, 0.406]
  PIXEL_STD: [0.229, 0.224, 0.225]

SOLVER:
  MAX_ITERS: 3000
  STEPS: [1200, 2400]
  OPTIMIZER_NAME: Adam
  BASE_LR: 0.00003
  WARMUP_ITERS: 0
  WEIGHT_DECAY: 0.0005

DATA:
  TRAIN_IMG_SOURCE: resource/datasets/CUB_200_2011/train.txt
  TEST_IMG_SOURCE: resource/datasets/CUB_200_2011/test.txt
  TRAIN_BATCHSIZE: 80
  TEST_BATCHSIZE: 256
  NUM_WORKERS: 8
  NUM_INSTANCES: 5

VALIDATION:
  VERBOSE: 200

================================================
FILE: requirements.txt
================================================
torch==1.1.0
numpy==1.15.4
yacs==0.1.4
setuptools==40.6.2
pytest==4.4.0
Pillow==8.3.2
torchvision==0.3.0

================================================
FILE: ret_benchmark/config/__init__.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .defaults import _C as cfg
================================================
FILE: ret_benchmark/config/defaults.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from yacs.config import CfgNode as CN

from .model_path import MODEL_PATH

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
_C = CN()

_C.MODEL = CN()
_C.MODEL.DEVICE = "cuda"
_C.MODEL.BACKBONE = CN()
_C.MODEL.BACKBONE.NAME = "bninception"
_C.MODEL.PRETRAIN = 'imagenet'
_C.MODEL.PRETRIANED_PATH = MODEL_PATH
_C.MODEL.HEAD = CN()
_C.MODEL.HEAD.NAME = "linear_norm"
_C.MODEL.HEAD.DIM = 512
_C.MODEL.WEIGHT = ""

# Checkpoint save dir
_C.SAVE_DIR = 'output'

# Loss
_C.LOSSES = CN()
_C.LOSSES.NAME = 'ms_loss'

# ms loss
_C.LOSSES.MULTI_SIMILARITY_LOSS = CN()
_C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS = 2.0
_C.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG = 40.0
_C.LOSSES.MULTI_SIMILARITY_LOSS.HARD_MINING = True

# margin loss
_C.LOSSES.MARGIN_LOSS = CN()
_C.LOSSES.MARGIN_LOSS.BETA_CONSTANT = False
_C.LOSSES.MARGIN_LOSS.N_CLASSES = 100
_C.LOSSES.MARGIN_LOSS.CUTOFF = 0.5
_C.LOSSES.MARGIN_LOSS.UPPER_CUTOFF = 1.4

# Data option
_C.DATA = CN()
_C.DATA.TRAIN_IMG_SOURCE = 'resource/datasets/CUB_200_2011/train.txt'
_C.DATA.TEST_IMG_SOURCE = 'resource/datasets/CUB_200_2011/test.txt'
_C.DATA.TRAIN_BATCHSIZE = 70
_C.DATA.TEST_BATCHSIZE = 256
_C.DATA.NUM_WORKERS = 8
_C.DATA.NUM_INSTANCES = 5

# Input option
_C.INPUT = CN()
_C.INPUT.MODE = 'BGR'
_C.INPUT.PIXEL_MEAN = [104. / 255, 117. / 255, 128. / 255]
_C.INPUT.PIXEL_STD = 3 * [1. / 255]
_C.INPUT.FLIP_PROB = 0.5
_C.INPUT.ORIGIN_SIZE = 256
_C.INPUT.CROP_SCALE = [0.16, 1]
_C.INPUT.CROP_SIZE = 227

# SOLVER
_C.SOLVER = CN()
_C.SOLVER.IS_FINETURN = False
_C.SOLVER.FINETURN_MODE_PATH = ''
_C.SOLVER.MAX_ITERS = 4000
_C.SOLVER.STEPS = [1000, 2000, 3000]
_C.SOLVER.OPTIMIZER_NAME = 'SGD'
_C.SOLVER.BASE_LR = 0.01
_C.SOLVER.BIAS_LR_FACTOR = 1
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.WARMUP_FACTOR = 0.01
_C.SOLVER.WARMUP_ITERS = 200
_C.SOLVER.WARMUP_METHOD = 'linear'
_C.SOLVER.CHECKPOINT_PERIOD = 200
_C.SOLVER.RNG_SEED = 1

# Logger
_C.LOGGER = CN()
_C.LOGGER.LEVEL = 20
_C.LOGGER.STREAM = 'stdout'

# Validation
_C.VALIDATION = CN()
_C.VALIDATION.VERBOSE = 200
_C.VALIDATION.IS_VALIDATION = True

================================================
FILE: ret_benchmark/config/model_path.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

# -----------------------------------------------------------------------------
# Config definition of imagenet pretrained model path
# -----------------------------------------------------------------------------
from yacs.config import CfgNode as CN

MODEL_PATH = {
    'bninception': "~/.torch/models/bn_inception-52deb4733.pth",
    'resnet50': "~/.torch/models/resnet50-19c8e357.pth",
}
MODEL_PATH = CN(MODEL_PATH)
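Since the rest of the package consumes this `cfg` tree, a minimal usage sketch may help: it merges `configs/example.yaml` over the compiled defaults with the standard yacs API. The printed values follow from the files above; the snippet itself is illustrative, not part of the repo.

```python
from ret_benchmark.config import cfg  # the _C tree defined in defaults.py

# Overlay an experiment YAML on top of the compiled defaults.
cfg.merge_from_file("configs/example.yaml")
cfg.freeze()

print(cfg.SOLVER.BASE_LR)  # 3e-05, overridden by example.yaml (default is 0.01)
print(cfg.LOSSES.NAME)     # 'ms_loss': the default, since example.yaml leaves it unset
print(cfg.MODEL.PRETRIANED_PATH['bninception'])  # '~/.torch/models/bn_inception-52deb4733.pth'
```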
================================================
FILE: ret_benchmark/data/__init__.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_data

================================================
FILE: ret_benchmark/data/build.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from torch.utils.data import DataLoader

from .collate_batch import collate_fn
from .datasets import BaseDataSet
from .samplers import RandomIdentitySampler
from .transforms import build_transforms


def build_data(cfg, is_train=True):
    transforms = build_transforms(cfg, is_train=is_train)
    if is_train:
        dataset = BaseDataSet(cfg.DATA.TRAIN_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
        sampler = RandomIdentitySampler(dataset=dataset,
                                        batch_size=cfg.DATA.TRAIN_BATCHSIZE,
                                        num_instances=cfg.DATA.NUM_INSTANCES,
                                        max_iters=cfg.SOLVER.MAX_ITERS
                                        )
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_fn,
                                 batch_sampler=sampler,
                                 num_workers=cfg.DATA.NUM_WORKERS,
                                 pin_memory=True
                                 )
    else:
        dataset = BaseDataSet(cfg.DATA.TEST_IMG_SOURCE, transforms=transforms, mode=cfg.INPUT.MODE)
        data_loader = DataLoader(dataset,
                                 collate_fn=collate_fn,
                                 shuffle=False,
                                 batch_size=cfg.DATA.TEST_BATCHSIZE,
                                 num_workers=cfg.DATA.NUM_WORKERS
                                 )
    return data_loader

================================================
FILE: ret_benchmark/data/collate_batch.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch


def collate_fn(batch):
    imgs, labels = zip(*batch)
    labels = [int(k) for k in labels]
    labels = torch.tensor(labels, dtype=torch.int64)
    return torch.stack(imgs, dim=0), labels

================================================
FILE: ret_benchmark/data/datasets/__init__.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .base_dataset import BaseDataSet

================================================
FILE: ret_benchmark/data/datasets/base_dataset.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import os
import re
from collections import defaultdict

from torch.utils.data import Dataset

from ret_benchmark.utils.img_reader import read_image


class BaseDataSet(Dataset):
    """
    Basic Dataset: reads image paths from img_source.
    img_source: list of img_path and label
    """

    def __init__(self, img_source, transforms=None, mode="RGB"):
        self.mode = mode
        self.transforms = transforms
        self.root = os.path.dirname(img_source)
        assert os.path.exists(img_source), f"{img_source} NOT found."
        self.img_source = img_source

        self.label_list = list()
        self.path_list = list()
        self._load_data()
        self.label_index_dict = self._build_label_index_dict()

    def __len__(self):
        return len(self.label_list)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return f"| Dataset Info |datasize: {self.__len__()}|num_labels: {len(set(self.label_list))}|"

    def _load_data(self):
        with open(self.img_source, 'r') as f:
            for line in f:
                _path, _label = re.split(r",| ", line.strip())
                self.path_list.append(_path)
                self.label_list.append(_label)

    def _build_label_index_dict(self):
        index_dict = defaultdict(list)
        for i, label in enumerate(self.label_list):
            index_dict[label].append(i)
        return index_dict

    def __getitem__(self, index):
        path = self.path_list[index]
        img_path = os.path.join(self.root, path)
        label = self.label_list[index]

        img = read_image(img_path, mode=self.mode)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label
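A small usage sketch for BaseDataSet (illustrative, not part of the repo). The example list lines are hypothetical, in the comma- or space-separated path/label format that `_load_data` parses; paths are resolved relative to the directory containing the list file:

```python
from ret_benchmark.data.datasets import BaseDataSet

# Hypothetical train.txt contents, both separator styles are accepted:
#
#   images/001.Black_footed_Albatross/img_0001.jpg,0
#   images/002.Laysan_Albatross/img_0042.jpg 1
#
dataset = BaseDataSet("resource/datasets/CUB_200_2011/train.txt")
print(dataset)           # | Dataset Info |datasize: ...|num_labels: ...|
img, label = dataset[0]  # a PIL image (no transforms passed) and its label string
```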
================================================
FILE: ret_benchmark/data/evaluations/ret_metric.py (continued)
================================================
import numpy as np


class RetMetric(object):
    def __init__(self, feats, labels):
        if isinstance(feats, list) and len(feats) == 2:
            # feats = [gallery_feats, query_feats]
            # labels = [gallery_labels, query_labels]
            self.is_equal_query = False
            self.gallery_feats, self.query_feats = feats
            self.gallery_labels, self.query_labels = labels
        else:
            self.is_equal_query = True
            self.gallery_feats = self.query_feats = feats
            self.gallery_labels = self.query_labels = labels

        self.sim_mat = np.matmul(self.query_feats, np.transpose(self.gallery_feats))

    def recall_k(self, k=1):
        m = len(self.sim_mat)
        match_counter = 0
        for i in range(m):
            pos_sim = self.sim_mat[i][self.gallery_labels == self.query_labels[i]]
            neg_sim = self.sim_mat[i][self.gallery_labels != self.query_labels[i]]

            # When query == gallery, the top positive is the query itself,
            # so threshold on the second-highest positive similarity.
            thresh = np.sort(pos_sim)[-2] if self.is_equal_query else np.max(pos_sim)

            if np.sum(neg_sim > thresh) < k:
                match_counter += 1
        return float(match_counter) / m
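A toy sanity check for RetMetric (illustrative; the features are made up). A query counts as a recall@k hit when fewer than k negatives score above its best positive:

```python
import numpy as np
from ret_benchmark.data.evaluations import RetMetric

# Four unit-normalized feature vectors, two classes (assumed toy data).
feats = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]], dtype=np.float32)
feats /= np.linalg.norm(feats, axis=1, keepdims=True)
labels = np.array([0, 0, 1, 1])

metric = RetMetric(feats=feats, labels=labels)
print(metric.recall_k(1))  # 1.0: every sample's nearest non-self neighbor shares its class
```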
""" def __init__(self, dataset, batch_size, num_instances, max_iters): self.label_index_dict = dataset.label_index_dict self.batch_size = batch_size self.K = num_instances self.num_labels_per_batch = self.batch_size // self.K self.max_iters = max_iters self.labels = list(self.label_index_dict.keys()) def __len__(self): return self.max_iters def __repr__(self): return self.__str__() def __str__(self): return f"|Sampler| iters {self.max_iters}| K {self.K}| M {self.batch_size}|" def _prepare_batch(self): batch_idxs_dict = defaultdict(list) for label in self.labels: idxs = copy.deepcopy(self.label_index_dict[label]) if len(idxs) < self.K: idxs.extend(np.random.choice(idxs, size=self.K - len(idxs), replace=True)) random.shuffle(idxs) batch_idxs_dict[label] = [idxs[i * self.K: (i + 1) * self.K] for i in range(len(idxs) // self.K)] avai_labels = copy.deepcopy(self.labels) return batch_idxs_dict, avai_labels def __iter__(self): batch_idxs_dict, avai_labels = self._prepare_batch() for _ in range(self.max_iters): batch = [] if len(avai_labels) < self.num_labels_per_batch: batch_idxs_dict, avai_labels = self._prepare_batch() selected_labels = random.sample(avai_labels, self.num_labels_per_batch) for label in selected_labels: batch_idxs = batch_idxs_dict[label].pop(0) batch.extend(batch_idxs) if len(batch_idxs_dict[label]) == 0: avai_labels.remove(label) yield batch ================================================ FILE: ret_benchmark/data/transforms/__init__.py ================================================ # Copyright (c) Malong Technologies Co., Ltd. # All rights reserved. # # Contact: github@malong.com # # This source code is licensed under the LICENSE file in the root directory of this source tree. from .build import build_transforms ================================================ FILE: ret_benchmark/data/transforms/build.py ================================================ # Copyright (c) Malong Technologies Co., Ltd. # All rights reserved. # # Contact: github@malong.com # # This source code is licensed under the LICENSE file in the root directory of this source tree. import torchvision.transforms as T def build_transforms(cfg, is_train=True): normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) if is_train: transform = T.Compose([ T.Resize(size=cfg.INPUT.ORIGIN_SIZE), T.RandomResizedCrop( scale=cfg.INPUT.CROP_SCALE, size=cfg.INPUT.CROP_SIZE ), T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB), T.ToTensor(), normalize_transform, ]) else: transform = T.Compose([ T.Resize(size=cfg.INPUT.ORIGIN_SIZE), T.CenterCrop(cfg.INPUT.CROP_SIZE), T.ToTensor(), normalize_transform ]) return transform ================================================ FILE: ret_benchmark/engine/__init__.py ================================================ # Copyright (c) Malong Technologies Co., Ltd. # All rights reserved. # # Contact: github@malong.com # # This source code is licensed under the LICENSE file in the root directory of this source tree. from .trainer import do_train ================================================ FILE: ret_benchmark/engine/trainer.py ================================================ # Copyright (c) Malong Technologies Co., Ltd. # All rights reserved. # # Contact: github@malong.com # # This source code is licensed under the LICENSE file in the root directory of this source tree. 
================================================
FILE: ret_benchmark/data/transforms/__init__.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_transforms

================================================
FILE: ret_benchmark/data/transforms/build.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torchvision.transforms as T


def build_transforms(cfg, is_train=True):
    normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)

    if is_train:
        transform = T.Compose([
            T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
            T.RandomResizedCrop(
                scale=cfg.INPUT.CROP_SCALE,
                size=cfg.INPUT.CROP_SIZE
            ),
            T.RandomHorizontalFlip(p=cfg.INPUT.FLIP_PROB),
            T.ToTensor(),
            normalize_transform,
        ])
    else:
        transform = T.Compose([
            T.Resize(size=cfg.INPUT.ORIGIN_SIZE),
            T.CenterCrop(cfg.INPUT.CROP_SIZE),
            T.ToTensor(),
            normalize_transform
        ])
    return transform

================================================
FILE: ret_benchmark/engine/__init__.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .trainer import do_train

================================================
FILE: ret_benchmark/engine/trainer.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import datetime
import time

import numpy as np
import torch

from ret_benchmark.data.evaluations import RetMetric
from ret_benchmark.utils.feat_extractor import feat_extractor
from ret_benchmark.utils.freeze_bn import set_bn_eval
from ret_benchmark.utils.metric_logger import MetricLogger


def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger
):
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(train_loader)

    start_iter = arguments["iteration"]
    best_iteration = -1
    best_recall = 0

    start_training_time = time.time()
    end = time.time()

    for iteration, (images, targets) in enumerate(train_loader, start_iter):
        if iteration % cfg.VALIDATION.VERBOSE == 0 or iteration == max_iter:
            model.eval()
            logger.info('Validation')

            labels = val_loader.dataset.label_list
            labels = np.array([int(k) for k in labels])
            feats = feat_extractor(model, val_loader, logger=logger)

            ret_metric = RetMetric(feats=feats, labels=labels)
            recall_curr = ret_metric.recall_k(1)

            if recall_curr > best_recall:
                best_recall = recall_curr
                best_iteration = iteration
                logger.info(f'Best iteration {iteration}: recall@1: {best_recall:.3f}')
                checkpointer.save("best_model")
            else:
                logger.info(f'Recall@1 at iteration {iteration:06d}: {recall_curr:.3f}')

        model.train()
        model.apply(set_bn_eval)

        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        scheduler.step()

        images = images.to(device)
        targets = torch.stack([target.to(device) for target in targets])

        feats = model(images)
        loss = criterion(feats, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()

        meters.update(time=batch_time, data=data_time, loss=loss.item())
        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.1f} GB",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0,
                )
            )

        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:06d}".format(iteration))

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max_iter
        )
    )
    logger.info(f"Best iteration: {best_iteration:06d} | best recall {best_recall}")
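Note that the loop re-applies `set_bn_eval` after every validation pass so BatchNorm statistics stay frozen while training. utils/freeze_bn.py is not reproduced in this extract, so the helper below is only a plausible sketch of such a function, not the repo's actual code:

```python
import torch.nn as nn

def set_bn_eval(module: nn.Module) -> None:
    # Hypothetical helper: switch every BatchNorm layer back to eval mode so
    # its running statistics are used (and left untouched) during training.
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        module.eval()
```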
================================================
FILE: ret_benchmark/losses/__init__.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_loss

================================================
FILE: ret_benchmark/losses/build.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .multi_similarity_loss import MultiSimilarityLoss
from .margin_loss import MarginLoss
from .registry import LOSS


def build_loss(cfg):
    loss_name = cfg.LOSSES.NAME
    assert loss_name in LOSS, \
        f'loss name {loss_name} is not registered in registry :{LOSS.keys()}'
    return LOSS[loss_name](cfg)

================================================
FILE: ret_benchmark/losses/margin_loss.py
================================================
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from ret_benchmark.losses.registry import LOSS


class DistanceWeightedSampling(object):
    """Distance-weighted negative sampling for triplets."""

    def __init__(self, cfg):
        super(DistanceWeightedSampling, self).__init__()
        self.cutoff = cfg.LOSSES.MARGIN_LOSS.CUTOFF
        self.upper_cutoff = cfg.LOSSES.MARGIN_LOSS.UPPER_CUTOFF

    def sample(self, batch, labels):
        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()
        bs = batch.shape[0]
        distances = self.p_dist(batch.detach()).clamp(min=self.cutoff)

        positives, negatives = [], []
        for i in range(bs):
            pos = labels == labels[i]
            q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i])
            # sample positives randomly
            pos[i] = 0
            positives.append(np.random.choice(np.where(pos)[0]))
            # sample negatives by distance
            negatives.append(np.random.choice(bs, p=q_d_inv))

        sampled_triplets = [[a, p, n] for a, p, n in zip(list(range(bs)), positives, negatives)]
        return sampled_triplets

    @staticmethod
    def p_dist(A, eps=1e-4):
        prod = torch.mm(A, A.t())
        norm = prod.diag().unsqueeze(1).expand_as(prod)
        res = (norm + norm.t() - 2 * prod).clamp(min=0)
        return res.clamp(min=eps).sqrt()

    def inverse_sphere_distances(self, batch, dist, labels, anchor_label):
        bs, dim = len(dist), batch.shape[-1]

        # negated log-distribution of distances on the unit sphere in this dimension
        log_q_d_inv = ((2.0 - float(dim)) * torch.log(dist)
                       - (float(dim - 3) / 2) * torch.log(1.0 - 0.25 * (dist.pow(2))))
        # zero out the log-weights of positives
        log_q_d_inv[np.where(labels == anchor_label)[0]] = 0

        q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv))  # - max(log) for stability
        # set sampling probabilities of positives to zero
        q_d_inv[np.where(labels == anchor_label)[0]] = 0

        # NOTE: Cutting off values with high distances made the results slightly worse.
        # q_d_inv[np.where(dist > self.upper_cutoff)[0]] = 0

        q_d_inv = q_d_inv / q_d_inv.sum()
        return q_d_inv.detach().cpu().numpy()


@LOSS.register("margin_loss")
class MarginLoss(nn.Module):
    """Margin-based loss with DistanceWeightedSampling."""

    def __init__(self, cfg):
        super(MarginLoss, self).__init__()
        self.beta_val = 1.2
        self.margin = 0.2
        self.nu = 0.0
        self.n_classes = cfg.LOSSES.MARGIN_LOSS.N_CLASSES
        self.beta_constant = cfg.LOSSES.MARGIN_LOSS.BETA_CONSTANT

        if self.beta_constant:
            self.beta = self.beta_val
        else:
            self.beta = torch.nn.Parameter(torch.ones(self.n_classes) * self.beta_val)

        self.sampler = DistanceWeightedSampling(cfg)

    def forward(self, batch, labels):
        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()
        sampled_triplets = self.sampler.sample(batch, labels)

        # compute distances between anchor-positive and anchor-negative.
        d_ap, d_an = [], []
        for triplet in sampled_triplets:
            train_triplet = {'Anchor': batch[triplet[0], :],
                             'Positive': batch[triplet[1], :],
                             'Negative': batch[triplet[2], :]}
            pos_dist = ((train_triplet['Anchor'] - train_triplet['Positive']).pow(2).sum() + 1e-8).pow(1 / 2)
            neg_dist = ((train_triplet['Anchor'] - train_triplet['Negative']).pow(2).sum() + 1e-8).pow(1 / 2)
            d_ap.append(pos_dist)
            d_an.append(neg_dist)
        d_ap, d_an = torch.stack(d_ap), torch.stack(d_an)

        # group betas together by anchor class in sampled triplets (as each beta belongs to one class).
        if self.beta_constant:
            beta = self.beta
        else:
            beta = torch.stack(
                [self.beta[labels[triplet[0]]] for triplet in sampled_triplets]
            ).type(torch.cuda.FloatTensor)

        # compute actual margin positive and margin negative loss
        pos_loss = F.relu(d_ap - beta + self.margin)
        neg_loss = F.relu(beta - d_an + self.margin)

        # compute normalization constant
        pair_count = torch.sum((pos_loss > 0.) + (neg_loss > 0.)).type(torch.cuda.FloatTensor)

        # actual Margin Loss
        loss = torch.sum(pos_loss + neg_loss) if pair_count == 0. else torch.sum(pos_loss + neg_loss) / pair_count

        # (Optional) Add regularization penalty on betas.
        # if self.nu: loss = loss + beta_regularisation_loss.type(torch.cuda.FloatTensor)

        return loss
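For reference, `inverse_sphere_distances` matches the distance-weighted sampling of Wu et al., "Sampling Matters in Deep Embedding Learning" (ICCV 2017): for random points on an n-dimensional unit sphere, pairwise distances d are distributed as q(d) below, and negatives are drawn with probability proportional to q(d)^{-1}. In the code's notation (dim = n, dist = d):

```latex
q(d) \;\propto\; d^{\,n-2}\left(1 - \tfrac{1}{4}d^{2}\right)^{\frac{n-3}{2}},
\qquad
\log q(d)^{-1} \;=\; (2-n)\,\log d \;-\; \frac{n-3}{2}\,\log\!\left(1 - \tfrac{1}{4}d^{2}\right).
```

Subtracting `max(log_q_d_inv)` before exponentiating is the usual log-sum-exp stabilization; clamping distances at CUTOFF bounds how strongly very close negatives are up-weighted.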
================================================
FILE: ret_benchmark/losses/multi_similarity_loss.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from ret_benchmark.losses.registry import LOSS


@LOSS.register('ms_loss')
class MultiSimilarityLoss(nn.Module):
    def __init__(self, cfg):
        super(MultiSimilarityLoss, self).__init__()
        self.thresh = 0.5
        self.margin = 0.1

        self.scale_pos = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_POS
        self.scale_neg = cfg.LOSSES.MULTI_SIMILARITY_LOSS.SCALE_NEG

    def forward(self, feats, labels):
        assert feats.size(0) == labels.size(0), \
            f"feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}"
        batch_size = feats.size(0)
        sim_mat = torch.matmul(feats, torch.t(feats))

        epsilon = 1e-5
        loss = list()

        for i in range(batch_size):
            pos_pair_ = sim_mat[i][labels == labels[i]]
            pos_pair_ = pos_pair_[pos_pair_ < 1 - epsilon]  # drop the self-pair
            neg_pair_ = sim_mat[i][labels != labels[i]]

            # pair mining: keep only pairs that are informative relative to
            # the hardest pair of the opposite kind, within a margin
            neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)]
            pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)]

            if len(neg_pair) < 1 or len(pos_pair) < 1:
                continue

            # weighting step
            pos_loss = 1.0 / self.scale_pos * torch.log(
                1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
            neg_loss = 1.0 / self.scale_neg * torch.log(
                1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
            loss.append(pos_loss + neg_loss)

        if len(loss) == 0:
            return torch.zeros([], requires_grad=True)

        loss = sum(loss) / batch_size
        return loss
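In the paper's notation, with pairwise cosine similarities S_{ik}, mined positive/negative sets P_i and N_i, threshold λ = `thresh`, and scales α = SCALE_POS, β = SCALE_NEG, the loop above computes the multi-similarity loss (my transcription, matching the code):

```latex
\mathcal{L}_{MS} = \frac{1}{m} \sum_{i=1}^{m}
\left\{
\frac{1}{\alpha} \log\!\Big[\, 1 + \sum_{k \in \mathcal{P}_i} e^{-\alpha (S_{ik} - \lambda)} \Big]
+ \frac{1}{\beta} \log\!\Big[\, 1 + \sum_{k \in \mathcal{N}_i} e^{\beta (S_{ik} - \lambda)} \Big]
\right\}
```

The soft-max-like sums inside the logs implement the paper's general pair weighting: each pair's gradient weight depends on its own similarity and on the other pairs mined for the same anchor.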
# # Contact: github@malong.com # # This source code is licensed under the LICENSE file in the root directory of this source tree. from .backbone import build_backbone from .build import build_model from .heads import build_head from .registry import BACKBONES, HEADS ================================================ FILE: ret_benchmark/modeling/backbone/__init__.py ================================================ from .build import build_backbone ================================================ FILE: ret_benchmark/modeling/backbone/bninception.py ================================================ from __future__ import absolute_import, division, print_function import torch import torch.nn as nn import torch.nn.functional as F from ret_benchmark.modeling import registry @registry.BACKBONES.register('bninception') class BNInception(nn.Module): def __init__(self): super(BNInception, self).__init__() inplace = True self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True) self.conv1_relu_7x7 = nn.ReLU(inplace) self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) self.conv2_relu_3x3_reduce = nn.ReLU(inplace) self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True) self.conv2_relu_3x3 = nn.ReLU(inplace) self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True) self.inception_3a_relu_1x1 = nn.ReLU(inplace) self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace) self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True) self.inception_3a_relu_3x3 = nn.ReLU(inplace) self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace) self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace) self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace) self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True) self.inception_3a_relu_pool_proj = nn.ReLU(inplace) self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True) self.inception_3b_relu_1x1 = nn.ReLU(inplace) self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 
        self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3b_relu_3x3 = nn.ReLU(inplace)
        self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3b_relu_pool_proj = nn.ReLU(inplace)
        self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_3c_relu_3x3 = nn.ReLU(inplace)
        self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
        self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True)
        self.inception_4a_relu_1x1 = nn.ReLU(inplace)
        self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
        self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4a_relu_3x3 = nn.ReLU(inplace)
        self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4a_relu_pool_proj = nn.ReLU(inplace)
        self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_4b_relu_1x1 = nn.ReLU(inplace)
        self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_3x3 = nn.ReLU(inplace)
        self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4b_relu_pool_proj = nn.ReLU(inplace)
        self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4c_relu_1x1 = nn.ReLU(inplace)
        self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4c_relu_3x3 = nn.ReLU(inplace)
        self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4c_relu_pool_proj = nn.ReLU(inplace)
        self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True)
        self.inception_4d_relu_1x1 = nn.ReLU(inplace)
        self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_4d_relu_3x3 = nn.ReLU(inplace)
        self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4d_relu_pool_proj = nn.ReLU(inplace)
        self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_4e_relu_3x3 = nn.ReLU(inplace)
        self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1))
        self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True)
        self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True)
        self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
        self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True)
        self.inception_5a_relu_1x1 = nn.ReLU(inplace)
        self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True)
        self.inception_5a_relu_3x3 = nn.ReLU(inplace)
        self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
        self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
        self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
        self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
        self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_5a_relu_pool_proj = nn.ReLU(inplace)
        self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True)
        self.inception_5b_relu_1x1 = nn.ReLU(inplace)
        self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace)
        self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True)
        self.inception_5b_relu_3x3 = nn.ReLU(inplace)
        self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
        self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace)
        self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
        self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace)
        self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
        self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace)
        self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True)
        self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
        self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
        self.inception_5b_relu_pool_proj = nn.ReLU(inplace)

    def features(self, input):
        conv1_7x7_s2_out = self.conv1_7x7_s2(input)
        conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out)
        conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out)
        pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out)
        conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out)
        conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out)
        conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out)
        conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out)
        conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out)
        conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out)
        pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out)
        inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out)
        inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out)
        inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out)
        inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out)
        inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out)
        inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out)
        inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out)
        inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out)
        inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out)
        inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out)
        inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(
            inception_3a_double_3x3_reduce_out)
        inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(
            inception_3a_double_3x3_reduce_bn_out)
        inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out)
        inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out)
        inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out)
        inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out)
        inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out)
        inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out)
        inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out)
        inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out)
        inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out)
        inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out)
        inception_3a_output_out = torch.cat(
            [inception_3a_relu_1x1_out, inception_3a_relu_3x3_out, inception_3a_relu_double_3x3_2_out,
             inception_3a_relu_pool_proj_out], 1)
        inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out)
        inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out)
        inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out)
        inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out)
        inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out)
        inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out)
        inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out)
        inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out)
        inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out)
        inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out)
        inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(
            inception_3b_double_3x3_reduce_out)
        inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(
            inception_3b_double_3x3_reduce_bn_out)
        inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out)
        inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out)
        inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out)
        inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out)
        inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out)
        inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out)
        inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out)
        inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out)
        inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out)
        inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out)
        inception_3b_output_out = torch.cat(
            [inception_3b_relu_1x1_out, inception_3b_relu_3x3_out, inception_3b_relu_double_3x3_2_out,
             inception_3b_relu_pool_proj_out], 1)
        inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out)
        inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out)
        inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out)
        inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out)
        inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out)
        inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out)
        inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out)
        inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(
            inception_3c_double_3x3_reduce_out)
        inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(
            inception_3c_double_3x3_reduce_bn_out)
        inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out)
        inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out)
        inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out)
        inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out)
        inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out)
        inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out)
        inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out)
        inception_3c_output_out = torch.cat(
            [inception_3c_relu_3x3_out, inception_3c_relu_double_3x3_2_out, inception_3c_pool_out], 1)
        inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out)
        inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out)
        inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out)
        inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out)
        inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out)
        inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out)
        inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out)
        inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out)
        inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out)
        inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out)
        inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(
            inception_4a_double_3x3_reduce_out)
        inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(
            inception_4a_double_3x3_reduce_bn_out)
        inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out)
        inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out)
        inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out)
        inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out)
        inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out)
        inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out)
        inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out)
        inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out)
        inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out)
        inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out)
        inception_4a_output_out = torch.cat(
            [inception_4a_relu_1x1_out, inception_4a_relu_3x3_out, inception_4a_relu_double_3x3_2_out,
             inception_4a_relu_pool_proj_out], 1)
        inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out)
        inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out)
        inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out)
        inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out)
        inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out)
        inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out)
        inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out)
        inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out)
        inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out)
        inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out)
        inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(
            inception_4b_double_3x3_reduce_out)
        inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(
            inception_4b_double_3x3_reduce_bn_out)
        inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out)
        inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out)
        inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out)
        inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out)
        inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out)
        inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out)
        inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out)
        inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out)
        inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out)
        inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out)
        inception_4b_output_out = torch.cat(
            [inception_4b_relu_1x1_out, inception_4b_relu_3x3_out, inception_4b_relu_double_3x3_2_out,
             inception_4b_relu_pool_proj_out], 1)
        inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out)
        inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out)
        inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out)
        inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out)
        inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out)
        inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out)
        inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out)
        inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out)
        inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out)
        inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out)
        inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(
            inception_4c_double_3x3_reduce_out)
        inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(
            inception_4c_double_3x3_reduce_bn_out)
        inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out)
        inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out)
        inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out)
        inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out)
        inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out)
        inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out)
        inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out)
        inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out)
        inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out)
        inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out)
        inception_4c_output_out = torch.cat(
            [inception_4c_relu_1x1_out, inception_4c_relu_3x3_out, inception_4c_relu_double_3x3_2_out,
             inception_4c_relu_pool_proj_out], 1)
        inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out)
        inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out)
        inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out)
        inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out)
        inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out)
        inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out)
        inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out)
        inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out)
        inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out)
        inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out)
        inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(
            inception_4d_double_3x3_reduce_out)
        inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(
            inception_4d_double_3x3_reduce_bn_out)
        inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out)
        inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out)
        inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out)
        inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out)
        inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out)
        inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out)
        inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out)
        inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out)
        inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out)
        inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out)
        inception_4d_output_out = torch.cat(
            [inception_4d_relu_1x1_out, inception_4d_relu_3x3_out, inception_4d_relu_double_3x3_2_out,
             inception_4d_relu_pool_proj_out], 1)
        inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out)
        inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out)
        inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out)
        inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out)
        inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out)
        inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out)
        inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out)
        inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(
            inception_4e_double_3x3_reduce_out)
        inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(
            inception_4e_double_3x3_reduce_bn_out)
        inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out)
        inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out)
        inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out)
        inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out)
        inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out)
        inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out)
        inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out)
        inception_4e_output_out = torch.cat(
            [inception_4e_relu_3x3_out, inception_4e_relu_double_3x3_2_out, inception_4e_pool_out], 1)
        inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out)
        inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out)
        inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out)
        inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out)
        inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out)
        inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out)
        inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out)
        inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out)
        inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out)
        inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out)
        inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(
            inception_5a_double_3x3_reduce_out)
        inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(
            inception_5a_double_3x3_reduce_bn_out)
        inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out)
        inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out)
        inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out)
        inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out)
        inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out)
        inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out)
        inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out)
        inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out)
        inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out)
        inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out)
        inception_5a_output_out = torch.cat(
            [inception_5a_relu_1x1_out, inception_5a_relu_3x3_out, inception_5a_relu_double_3x3_2_out,
             inception_5a_relu_pool_proj_out], 1)
        inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out)
        inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out)
        inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out)
        inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out)
        inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out)
        inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out)
        inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out)
        inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out)
        inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out)
        inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out)
        inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(
            inception_5b_double_3x3_reduce_out)
        inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(
            inception_5b_double_3x3_reduce_bn_out)
        inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out)
        inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out)
        inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out)
        inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out)
        inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out)
        inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out)
        inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out)
        inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out)
        inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out)
        inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out)
        inception_5b_output_out = torch.cat(
            [inception_5b_relu_1x1_out, inception_5b_relu_3x3_out, inception_5b_relu_double_3x3_2_out,
             inception_5b_relu_pool_proj_out], 1)
        return inception_5b_output_out

    def logits(self, features):
        x = F.adaptive_max_pool2d(features, output_size=1)
        x = x.view(x.size(0), -1)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x

    def load_param(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            if 'last_linear' in i:
                continue
            self.state_dict()[i].copy_(param_dict[i])
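
A minimal shape check for the backbone above (illustrative only; BNInception takes no constructor arguments, as assumed by build_backbone below). The four inception_5b branches concatenate to 352 + 320 + 224 + 128 = 1024 channels, so the pooled feature is 1024-dimensional:

    import torch

    net = BNInception().eval()
    with torch.no_grad():
        feat = net(torch.randn(2, 3, 224, 224))
    print(feat.shape)  # expected: torch.Size([2, 1024])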
================================================
FILE: ret_benchmark/modeling/backbone/build.py
================================================
from ret_benchmark.modeling.registry import BACKBONES
from .bninception import BNInception
from .resnet import ResNet50


def build_backbone(cfg):
    assert cfg.MODEL.BACKBONE.NAME in BACKBONES, \
        f"backbone {cfg.MODEL.BACKBONE.NAME} is not registered in registry : {BACKBONES.keys()}"
    return BACKBONES[cfg.MODEL.BACKBONE.NAME]()


================================================
FILE: ret_benchmark/modeling/backbone/resnet.py
================================================
from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
import torchvision.models as models

from ret_benchmark.modeling import registry


@registry.BACKBONES.register('resnet50')
class ResNet50(nn.Module):

    def __init__(self):
        super(ResNet50, self).__init__()
        self.model = models.resnet50(pretrained=True)
        # Freeze all BatchNorm layers: force eval mode and disable mode switching.
        for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()):
            module.eval()
            module.train = lambda _: None

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = x.view(x.size(0), -1)
        # x = self.model.fc(x) --remove
        return x

    def load_param(self, model_path):
        param_dict = torch.load(model_path)
        for i in param_dict:
            if 'last_linear' in i:
                continue
            self.model.state_dict()[i].copy_(param_dict[i])


================================================
FILE: ret_benchmark/modeling/build.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import os
from collections import OrderedDict

import torch
from torch.nn.modules import Sequential

from .backbone import build_backbone
from .heads import build_head


def build_model(cfg):
    backbone = build_backbone(cfg)
    head = build_head(cfg)
    model = Sequential(OrderedDict([
        ('backbone', backbone),
        ('head', head)
    ]))
    if cfg.MODEL.PRETRAIN == 'imagenet':
        print('Loading imagenet pretrained model ...')
        # Note: the config key keeps its original spelling to match the defaults.
        pretrained_path = os.path.expanduser(cfg.MODEL.PRETRIANED_PATH[cfg.MODEL.BACKBONE.NAME])
        model.backbone.load_param(pretrained_path)
    elif os.path.exists(cfg.MODEL.PRETRAIN):
        ckp = torch.load(cfg.MODEL.PRETRAIN)
        model.load_state_dict(ckp['model'])
    return model


================================================
FILE: ret_benchmark/modeling/heads/__init__.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from .build import build_head


================================================
FILE: ret_benchmark/modeling/heads/build.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from ret_benchmark.modeling.registry import HEADS
from .linear_norm import LinearNorm


def build_head(cfg):
    assert cfg.MODEL.HEAD.NAME in HEADS, f"head {cfg.MODEL.HEAD.NAME} is not defined"
    # BNInception yields 1024-d pooled features; ResNet-50 yields 2048-d.
    return HEADS[cfg.MODEL.HEAD.NAME](cfg, in_channels=1024 if cfg.MODEL.BACKBONE.NAME == 'bninception' else 2048)
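
build_model relies on nn.Sequential exposing OrderedDict keys as attributes, which is why model.backbone.load_param(...) works. A minimal sketch of that behaviour (the Identity modules are stand-ins):

    from collections import OrderedDict
    from torch import nn

    model = nn.Sequential(OrderedDict([
        ('backbone', nn.Identity()),
        ('head', nn.Identity()),
    ]))
    assert model.backbone is model[0]  # named children are attribute-accessible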
================================================
FILE: ret_benchmark/modeling/heads/linear_norm.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from ret_benchmark.modeling.registry import HEADS
from ret_benchmark.utils.init_methods import weights_init_kaiming


@HEADS.register('linear_norm')
class LinearNorm(nn.Module):

    def __init__(self, cfg, in_channels):
        super(LinearNorm, self).__init__()
        self.fc = nn.Linear(in_channels, cfg.MODEL.HEAD.DIM)
        self.fc.apply(weights_init_kaiming)

    def forward(self, x):
        x = self.fc(x)
        # L2-normalize embeddings so cosine similarity reduces to a dot product.
        x = nn.functional.normalize(x, p=2, dim=1)
        return x


================================================
FILE: ret_benchmark/modeling/registry.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

from ret_benchmark.utils.registry import Registry

BACKBONES = Registry()
HEADS = Registry()


================================================
FILE: ret_benchmark/modeling/xbm.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
import tqdm

from ret_benchmark.data.build import build_memory_data


class XBM:

    def __init__(self, cfg, model):
        self.ratio = cfg.MEMORY.RATIO
        # init memory
        self.feats = list()
        self.labels = list()
        self.indices = list()
        model.train()
        for images, labels, indices in build_memory_data(cfg):
            with torch.no_grad():
                feat = model(images.cuda())
            self.feats.append(feat)
            self.labels.append(labels.cuda())
            self.indices.append(indices.cuda())
        self.feats = torch.cat(self.feats, dim=0)
        self.labels = torch.cat(self.labels, dim=0)
        self.indices = torch.cat(self.indices, dim=0)
        # if memory_ratio != 1.0 -> random sample init queue_mask to mimic fixed queue size
        if self.ratio != 1.0:
            rand_init_idx = torch.randperm(int(self.indices.shape[0] * self.ratio)).cuda()
            self.queue_mask = self.indices[rand_init_idx]

    def enqueue_dequeue(self, feats, indices):
        self.feats.data[indices] = feats
        if self.ratio != 1.0:
            # enqueue
            self.queue_mask = torch.cat((self.queue_mask, indices.cuda()), dim=0)
            # dequeue
            self.queue_mask = self.queue_mask[-int(self.indices.shape[0] * self.ratio):]

    def get(self):
        if self.ratio != 1.0:
            return self.feats[self.queue_mask], self.labels[self.queue_mask]
        else:
            return self.feats, self.labels
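
A sketch of how this memory is typically consumed inside a training step. The criterion signature here is an assumption for illustration; the actual call lives in ret_benchmark/engine/trainer.py:

    # xbm = XBM(cfg, model); images/labels/indices come from the train-loader batch
    feats = model(images.cuda())
    xbm.enqueue_dequeue(feats.detach(), indices.cuda())
    mem_feats, mem_labels = xbm.get()
    loss = criterion(feats, labels.cuda(), mem_feats, mem_labels)  # hypothetical signature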
================================================
FILE: ret_benchmark/solver/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from .build import build_optimizer
from .build import build_lr_scheduler
from .lr_scheduler import WarmupMultiStepLR


================================================
FILE: ret_benchmark/solver/build.py
================================================
import torch

from .lr_scheduler import WarmupMultiStepLR


def build_optimizer(cfg, model):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        # Backbone parameters are tagged with a 0.1 lr multiplier; the "lr_mul"
        # key is stored in the param group for downstream consumers.
        lr_mul = 1.0
        if "backbone" in key:
            lr_mul = 0.1
        params += [{"params": [value], "lr_mul": lr_mul}]
    optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(
        params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    return optimizer


def build_lr_scheduler(cfg, optimizer):
    return WarmupMultiStepLR(
        optimizer,
        cfg.SOLVER.STEPS,
        cfg.SOLVER.GAMMA,
        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
        warmup_iters=cfg.SOLVER.WARMUP_ITERS,
        warmup_method=cfg.SOLVER.WARMUP_METHOD,
    )


================================================
FILE: ret_benchmark/solver/lr_scheduler.py
================================================
from bisect import bisect_right

import torch


class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):

    def __init__(
        self,
        optimizer,
        milestones,
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                alpha = float(self.last_epoch) / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]
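
A worked example of the schedule above: with base lr 0.1, warmup_factor=0.5, warmup_iters=2, milestones=[4, 8] and gamma=0.1, the factor interpolates linearly from 0.5 to 1.0 over the first two steps, then decays by 0.1 at each milestone, giving 0.05, 0.075, 0.1, 0.1, 0.01, ..., 0.001:

    import torch
    from torch import nn

    opt = torch.optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)
    sched = WarmupMultiStepLR(opt, milestones=[4, 8], gamma=0.1,
                              warmup_factor=0.5, warmup_iters=2)
    for step in range(10):
        print(step, opt.param_groups[0]['lr'])
        opt.step()
        sched.step()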
================================================
FILE: ret_benchmark/utils/checkpoint.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import logging
import os

import torch

from ret_benchmark.utils.model_serialization import load_state_dict


class Checkpointer(object):

    def __init__(
        self,
        model,
        optimizer=None,
        scheduler=None,
        save_dir="",
        save_to_disk=None,
        logger=None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger

    def save(self, name):
        if not self.save_dir:
            return
        data = {}
        data["model"] = self.model.state_dict()
        if self.optimizer is not None:
            data["optimizer"] = self.optimizer.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        save_file = os.path.join(self.save_dir, "{}.pth".format(name))
        self.logger.info("Saving checkpoint to {}".format(save_file))
        torch.save(data, save_file)

    def load(self, f=None):
        if self.has_checkpoint():
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()
        if not f:
            # no checkpoint could be found
            self.logger.info("No checkpoint found. Initializing model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(f))
        checkpoint = self._load_file(f)
        self._load_model(checkpoint)
        if "optimizer" in checkpoint and self.optimizer:
            self.logger.info("Loading optimizer from {}".format(f))
            self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info("Loading scheduler from {}".format(f))
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return os.path.exists(save_file)

    def get_checkpoint_file(self):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with open(save_file, "r") as f:
                last_saved = f.read()
                last_saved = last_saved.strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            last_saved = ""
        return last_saved

    def tag_last_checkpoint(self, last_filename):
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with open(save_file, "w") as f:
            f.write(last_filename)

    def _load_file(self, f):
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint):
        load_state_dict(self.model, checkpoint.pop("model"))
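
A typical round trip with the Checkpointer above (the model/optimizer/scheduler objects and the paths are illustrative):

    checkpointer = Checkpointer(model, optimizer, scheduler, save_dir='output')
    checkpointer.save('iter_01000')                            # writes output/iter_01000.pth
    checkpointer.tag_last_checkpoint('output/iter_01000.pth')  # records it in output/last_checkpoint
    extra = checkpointer.load()  # on restart, resumes from the tagged checkpoint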
================================================
FILE: ret_benchmark/utils/config_util.py
================================================
from __future__ import (absolute_import, division, print_function, unicode_literals)
import copy
import os

from ret_benchmark.config import cfg as g_cfg


def get_config_root_path():
    ''' Path to configs for unit tests '''
    # cur_file_dir is root/tests/env_tests
    cur_file_dir = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
    ret = os.path.dirname(os.path.dirname(cur_file_dir))
    ret = os.path.join(ret, "configs")
    return ret


def load_config(rel_path):
    ''' Load config from a file path specified relative to the config root '''
    cfg_path = os.path.join(get_config_root_path(), rel_path)
    return load_config_from_file(cfg_path)


def load_config_from_file(file_path):
    ''' Load config from a file path specified as an absolute path '''
    ret = copy.deepcopy(g_cfg)
    ret.merge_from_file(file_path)
    return ret


================================================
FILE: ret_benchmark/utils/feat_extractor.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import torch
import numpy as np


def feat_extractor(model, data_loader, logger=None):
    model.eval()
    feats = list()
    for i, batch in enumerate(data_loader):
        imgs = batch[0].cuda()
        with torch.no_grad():
            out = model(imgs).data.cpu().numpy()
            feats.append(out)
        if logger is not None and (i + 1) % 100 == 0:
            logger.debug(f'Extract Features: [{i + 1}/{len(data_loader)}]')
    del out
    feats = np.vstack(feats)
    return feats


================================================
FILE: ret_benchmark/utils/freeze_bn.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

# Batch Norm Freezer
# Note: adds an additional 2% improvement on CUB (on other benchmarks it has no effect)
def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()
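
The freezer above is meant to be broadcast over a model with Module.apply; a one-line sketch (the model name is illustrative, and this is typically re-applied after model.train() so BN statistics stay frozen):

    model.apply(set_bn_eval)  # forces every BatchNorm layer back into eval mode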
""" def __init__(self, window_size=20): self.deque = deque(maxlen=window_size) self.series = [] self.total = 0.0 self.count = 0 def update(self, value): self.deque.append(value) self.series.append(value) self.count += 1 self.total += value @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque)) return d.mean().item() @property def global_avg(self): return self.total / self.count class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg) ) return self.delimiter.join(loss_str) ================================================ FILE: ret_benchmark/utils/model_serialization.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. from collections import OrderedDict import logging import torch def align_and_update_state_dicts(model_state_dict, loaded_state_dict): """ Strategy: suppose that the models that we will create will have prefixes appended to each of its keys, for example due to an extra level of nesting that the original pre-trained weights from ImageNet won't contain. For example, model.state_dict() might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains res2.conv1.weight. We thus want to match both parameters together. For that, we look for each model weight, look among all loaded keys if there is one that is a suffix of the current weight name, and use it if that's the case. If multiple matches exist, take the one with longest size of the corresponding name. For example, for the same model as before, the pretrained weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, we want to match backbone[0].body.conv1.weight to conv1.weight, and backbone[0].body.res2.conv1.weight to res2.conv1.weight. 
""" current_keys = sorted(list(model_state_dict.keys())) loaded_keys = sorted(list(loaded_state_dict.keys())) # get a matrix of string matches, where each (i, j) entry correspond to the size of the # loaded_key string, if it matches match_matrix = [ len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys ] match_matrix = torch.as_tensor(match_matrix).view( len(current_keys), len(loaded_keys) ) max_match_size, idxs = match_matrix.max(1) # remove indices that correspond to no-match idxs[max_match_size == 0] = -1 # used for logging max_size = max([len(key) for key in current_keys]) if current_keys else 1 max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1 log_str_template = "{: <{}} loaded from {: <{}} of shape {}" logger = logging.getLogger(__name__) for idx_new, idx_old in enumerate(idxs.tolist()): if idx_old == -1: continue key = current_keys[idx_new] key_old = loaded_keys[idx_old] model_state_dict[key] = loaded_state_dict[key_old] logger.info( log_str_template.format( key, max_size, key_old, max_size_loaded, tuple(loaded_state_dict[key_old].shape), ) ) def strip_prefix_if_present(state_dict, prefix): keys = sorted(state_dict.keys()) if not all(key.startswith(prefix) for key in keys): return state_dict stripped_state_dict = OrderedDict() for key, value in state_dict.items(): stripped_state_dict[key.replace(prefix, "")] = value return stripped_state_dict def load_state_dict(model, loaded_state_dict): model_state_dict = model.state_dict() # if the state_dict comes from a model that was wrapped in a # DataParallel or DistributedDataParallel during serialization, # remove the "module" prefix before performing the matching loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.") align_and_update_state_dicts(model_state_dict, loaded_state_dict) # use strict loading model.load_state_dict(model_state_dict) ================================================ FILE: ret_benchmark/utils/registry.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. def _register_generic(module_dict, module_name, module): assert module_name not in module_dict module_dict[module_name] = module class Registry(dict): ''' A helper class for managing registering modules, it extends a dictionary and provides a register functions. Eg. creeting a registry: some_registry = Registry({"default": default_module}) There're two ways of registering new modules: 1): normal way is just calling register function: def foo(): ... some_registry.register("foo_module", foo) 2): used as decorator when declaring the module: @some_registry.register("foo_module") @some_registry.register("foo_modeul_nickname") def foo(): ... Access of module is just like using a dictionary, eg: f = some_registry["foo_modeul"] ''' def __init__(self, *args, **kwargs): super(Registry, self).__init__(*args, **kwargs) def register(self, module_name, module=None): # used as function call if module is not None: _register_generic(self, module_name, module) return # used as decorator def register_fn(fn): _register_generic(self, module_name, fn) return fn return register_fn ================================================ FILE: scripts/prepare_cub.sh ================================================ #!/bin/bash set -e CUB_ROOT='resource/datasets/CUB_200_2011/' CUB_DATA='http://www.vision.caltech.edu.s3-us-west-2.amazonaws.com/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' if [[ ! 
-d "${CUB_ROOT}" ]]; then mkdir -p resource/datasets pushd resource/datasets echo "Downloading CUB_200_2011 data-set..." wget ${CUB_DATA} tar -zxf CUB_200_2011.tgz popd fi # Generate train.txt and test.txt splits echo "Generating the train.txt/test.txt split files" python scripts/split_cub_for_ms_loss.py ================================================ FILE: scripts/run_cub.sh ================================================ #!/bin/bash OUT_DIR="output" if [[ ! -d "${OUT_DIR}" ]]; then echo "Creating output dir for training : ${OUT_DIR}" mkdir ${OUT_DIR} fi CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example.yaml ================================================ FILE: scripts/run_cub_margin.sh ================================================ #!/bin/bash OUT_DIR="output_margin" if [[ ! -d "${OUT_DIR}" ]]; then echo "Creating output dir for training : ${OUT_DIR}" mkdir ${OUT_DIR} fi CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example_margin.yaml ================================================ FILE: scripts/split_cub_for_ms_loss.py ================================================ cub_root = 'resource/datasets/CUB_200_2011/' images_file = cub_root + 'images.txt' train_file = cub_root + 'train.txt' test_file = cub_root + 'test.txt' def main(): train = [] test = [] with open(images_file) as f_img: for l_img in f_img: i, fname = l_img.split() label = int(fname.split('.', 1)[0]) if label <= 100: train.append((fname, label - 1)) # labels 0 ... 99 (0-based labels for margin_loss) else: test.append((fname, label - 1)) # labels 100 ... 199 for f, v in [(train_file, train), (test_file, test)]: with open(f, 'w') as tf: for fname, label in v: print("images/{},{}".format(fname, label), file=tf) if __name__ == '__main__': main() ================================================ FILE: setup.py ================================================ # Copyright (c) Malong Technologies Co., Ltd. # All rights reserved. # # Contact: github@malong.com # # This source code is licensed under the LICENSE file in the root directory of this source tree. import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import CppExtension requirements = ["torch", "torchvision"] setup( name="ret_benchmark", version="0.1", author="Malong Technologies", url="https://github.com/MalongTech/research-ms-loss", description="ms-loss", packages=find_packages(exclude=("configs", "tests")), install_requires=requirements, cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: tools/main.py ================================================ # Copyright (c) Malong Technologies Co., Ltd. # All rights reserved. # # Contact: github@malong.com # # This source code is licensed under the LICENSE file in the root directory of this source tree. 
================================================
FILE: tools/main.py
================================================
# Copyright (c) Malong Technologies Co., Ltd.
# All rights reserved.
#
# Contact: github@malong.com
#
# This source code is licensed under the LICENSE file in the root directory of this source tree.

import argparse

import torch

from ret_benchmark.config import cfg
from ret_benchmark.data import build_data
from ret_benchmark.engine.trainer import do_train
from ret_benchmark.losses import build_loss
from ret_benchmark.modeling import build_model
from ret_benchmark.solver import build_lr_scheduler, build_optimizer
from ret_benchmark.utils.logger import setup_logger
from ret_benchmark.utils.checkpoint import Checkpointer


def train(cfg):
    logger = setup_logger(name='Train', level=cfg.LOGGER.LEVEL)
    logger.info(cfg)
    model = build_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    criterion = build_loss(cfg)

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    train_loader = build_data(cfg, is_train=True)
    val_loader = build_data(cfg, is_train=False)

    logger.info(train_loader.dataset)
    logger.info(val_loader.dataset)

    arguments = dict()
    arguments["iteration"] = 0

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    checkpointer = Checkpointer(model, optimizer, scheduler, cfg.SAVE_DIR)

    do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        criterion,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        logger
    )


def parse_args():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(description='Train a retrieval network')
    parser.add_argument('--cfg', dest='cfg_file', help='config file', default=None, type=str)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    cfg.merge_from_file(args.cfg_file)
    train(cfg)
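
Putting it together, training is launched exactly as in scripts/run_cub.sh, by pointing the entry point at a YAML config that is merged into the defaults via cfg.merge_from_file:

    CUDA_VISIBLE_DEVICES=0 python3.6 tools/main.py --cfg configs/example.yaml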