Repository: sail-sg/iFormer
Branch: main
Commit: 725d8e7f455b
Files: 21
Total size: 228.8 KB

Directory structure:
gitextract_p3kgg1_k/
├── LICENSE
├── MANIFEST.in
├── README.md
├── checkpoint/
│   ├── iformer_base/
│   │   ├── args.yaml
│   │   └── summary.csv
│   ├── iformer_large/
│   │   ├── args.yaml
│   │   └── summary.csv
│   └── iformer_small/
│       ├── args.yaml
│       └── summary.csv
├── checkpoint_384/
│   ├── iformer_base_384/
│   │   ├── args.yaml
│   │   └── summary.csv
│   ├── iformer_large_384/
│   │   ├── args.yaml
│   │   └── summary.csv
│   └── iformer_small_384/
│       ├── args.yaml
│       └── summary.csv
├── fine-tune.py
├── models/
│   ├── __init__.py
│   └── inception_transformer.py
├── setup.cfg
├── train.py
└── validate.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright 2019 Ross Wightman

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

================================================
FILE: MANIFEST.in
================================================
include timm/models/pruned/*.txt

================================================
FILE: README.md
================================================
# iFormer: [Inception Transformer](http://arxiv.org/abs/2205.12956) (NeurIPS 2022 Oral)

This is a PyTorch implementation of iFormer proposed by our paper "[Inception Transformer](http://arxiv.org/abs/2205.12956)".

## Image Classification

### 1. Requirements

torch>=1.7.0; torchvision>=0.8.1; timm==0.5.4; fvcore; [apex-amp](https://github.com/NVIDIA/apex) (if you want to use fp16).

Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).

```
│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
```

### Main results on ImageNet-1K

| Model | #params | FLOPs | Image resolution | acc@1 | Download |
| :--- | :---: | :---: | :---: | :---: | :---: |
| iFormer-S | 20M | 4.8G | 224 | 83.4 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_small/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_small/summary.csv) |
| iFormer-B | 48M | 9.4G | 224 | 84.6 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_base/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_base/summary.csv) |
| iFormer-L | 87M | 14.0G | 224 | 84.8 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_large/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_large/summary.csv) |

Fine-tuning results with larger resolution (384x384) on ImageNet-1K:

| Model | #params | FLOPs | Image resolution | acc@1 | Download |
| :--- | :---: | :---: | :---: | :---: | :---: |
| iFormer-S | 20M | 16.1G | 384 | 84.6 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small_384.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_small_384/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_small_384/summary.csv) |
| iFormer-B | 48M | 30.5G | 384 | 85.7 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base_384.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_base_384/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_base_384/summary.csv) |
| iFormer-L | 87M | 45.3G | 384 | 85.8 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large_384.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_large_384/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_large_384/summary.csv) |

### Training

Train iformer_small at resolution 224:

```bash
python -m torch.distributed.launch --nproc_per_node=8 train.py /dataset/imagenet \
  --model iformer_small -b 128 --epochs 300 --img-size 224 --drop-path 0.2 --lr 1e-3 \
  --weight-decay 0.05 --aa rand-m9-mstd0.5-inc1 --warmup-lr 1e-6 --warmup-epochs 5 \
  --output checkpoint --min-lr 1e-6 --experiment iformer_small
```

Fine-tune at resolution 384 from the pretrained 224 checkpoint:

```bash
python -m torch.distributed.launch --nproc_per_node=8 fine-tune.py /dataset/imagenet \
  --model iformer_small_384 -b 64 --lr 1e-5 --min-lr 1e-6 --warmup-lr 2e-8 --warmup-epochs 0 \
  --epochs 20 --img-size 384 --drop-path 0.3 --weight-decay 1e-8 --mixup 0.1 --cutmix 0.1 \
  --cooldown-epochs 10 --aa rand-m9-mstd0.5-inc1 --clip-grad 1.0 --output checkpoint_fine \
  --initial-checkpoint checkpoint/iformer_small/model_best.pth.tar \
  --experiment iformer_small_384
```

### Validation

```bash
python validate.py /dataset/imagenet --model iformer_small --checkpoint checkpoint/iformer_small/model_best.pth.tar
```

## Object Detection and Instance Segmentation

All models are based on Mask R-CNN and trained with the 1x training schedule.

| Backbone | #Param. | FLOPs | box mAP | mask mAP |
|:---------:|:-------:|:-----:|:-------:|:--------:|
| iFormer-S | 40M | 263G | 46.2 | 41.9 |
| iFormer-B | 67M | 351G | 48.3 | 43.3 |

## Semantic Segmentation

| Backbone | Method | #Param. | FLOPs | mIoU |
|:---------:|:-------:|:-------:|:-----:|:----:|
| iFormer-S | FPN | 24M | 181G | 48.6 |
| iFormer-S | Upernet | 49M | 938G | 48.4 |

## Bibtex

```
@inproceedings{
si2022inception,
title={Inception Transformer},
author={Chenyang Si and Weihao Yu and Pan Zhou and Yichen Zhou and Xinchao Wang and Shuicheng YAN},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
```

## Acknowledgment

Our implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful works: [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [mmdetection](https://github.com/open-mmlab/mmdetection), [mmsegmentation](https://github.com/open-mmlab/mmsegmentation). Besides, Weihao Yu would like to thank the TPU Research Cloud (TRC) program for the support of partial computational resources.

================================================
FILE: checkpoint/iformer_base/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: false
aug_repeats: 3
aug_splits: 0
batch_size: 64
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 1.0
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 30.0
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.4
embed_dim: 384
epoch_repeats: 0.0
epochs: 300
eval_metric: top1
experiment: iformer_base
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.001
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-05
mixup: 0.8
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_base
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: 1.0e-08
output: checkpoint
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 5
warmup_lr: 1.0e-06
weight_decay: 0.05
worker_seeding: all
workers: 10

================================================
FILE: checkpoint/iformer_base/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,6.908455812014067,6.860800276947021,0.3619999967956543,1.5300000119018555
1,6.665643343558679,5.8470182475280765,3.750000009765625,12.358000017089843
2,6.3399158624502325,4.92742544342041,11.036000020751953,27.733999997558595
3,6.007291243626521,4.11241763496399,20.144000025634767,42.53600006103515
4,5.740854611763587,3.553922503890991,27.132000051269532,53.624000043945315
5,5.421774644118089,3.008195775489807,37.086000013427736,64.2180000024414
6,5.238544537470891,2.725795027923584,42.46799998046875,69.79000006591797
7,5.00539570588332,2.439262678527832,48.0880001171875,74.49400002685547
8,4.882717260947595,2.2806560862350462,51.30599999267578,77.654000078125
9,4.747562701885517,2.09722495513916,54.68799995361328,80.28000007080078
10,4.535913467407227,1.9907448290634155,57.23400012695313,82.21600009033203
11,4.489044281152578,1.9909129524230957,58.418000046386716,83.2179999609375
12,4.439941516289344,1.8225921615219116,60.41000010986328,84.44999995117188 13,4.332764918987568,1.8817005332565309,61.10000019775391,85.06000002929687 14,4.261882268465483,1.6659569248199464,63.4940000024414,86.6240000024414 15,4.186580703808711,1.6334883233642579,63.90200006591797,86.71199992675781 16,4.21221111370967,1.6058608751106263,65.52199997314453,87.83800005126953 17,4.1268850198158855,1.525306345729828,65.95999991699219,88.33599999511719 18,4.047768574494582,1.518114146270752,66.62600010009766,88.70400015136718 19,4.040863623985877,1.525582196998596,67.52799990234375,89.14200002441406 20,4.001404560529268,1.484820670261383,67.73800007080078,89.16000020263672 21,3.9741526475319495,1.4245139252471923,68.33000017822266,89.64000015136719 22,3.902811747330886,1.4404746714973449,68.7499999951172,89.90000004638672 23,3.888140925994286,1.452209776649475,68.85600004882812,89.92599997070313 24,3.8569280092532816,1.355765485534668,69.79400004638671,90.53599996582031 25,3.874772942983187,1.322539180431366,69.96800012451172,90.5139998876953 26,3.824866845057561,1.3584467635345459,70.43400001953125,90.71600004638672 27,3.815111151108375,1.3259017392158507,70.89200012451172,91.10200007080078 28,3.7768686551314135,1.2894453044128418,71.04800004882813,91.17800004394532 29,3.7479228789989767,1.3424655978775024,71.19000008789062,91.07599999023438 30,3.7662950295668383,1.2917246441841126,71.43599998779297,91.60400004150391 31,3.754122550670917,1.3602314675521852,71.98599995605468,91.6720001196289 32,3.7166288174115696,1.2361285284996033,72.2459999584961,91.74399998779298 33,3.7051904568305383,1.2548866311454774,72.2139999609375,91.86599998779297 34,3.6960658752001248,1.2564913172149659,71.99000001708984,91.8040001220703 35,3.709755466534541,1.2993868688964845,72.33400013916015,91.94400011962891 36,3.695227127808791,1.2078956381416321,72.6780001196289,92.0100000390625 37,3.641167613176199,1.237737746334076,72.6480000415039,92.0580001147461 
38,3.6478970601008487,1.1785106216049195,73.3219999633789,92.4960000415039 39,3.640158689939059,1.202273616809845,73.29000001708984,92.45399999023438 40,3.7028306814340444,1.2195009224510194,73.21800003662109,92.23600014404298 41,3.6155638694763184,1.2656590075683594,73.55800003417968,92.4960000390625 42,3.6116727957358727,1.2079661503982544,73.4180000415039,92.51400006591797 43,3.6359497217031627,1.1636047567939758,73.47799993408204,92.46399996337891 44,3.644525775542626,1.2462298067855835,73.5419999609375,92.73400001464844 45,3.5377747187247643,1.2198820894241333,73.96600000976562,92.71400004150391 46,3.577638406019944,1.1579090527915954,74.22400000488281,92.93800008789063 47,3.578922601846548,1.1946468775939942,74.06000000488281,92.89200006591797 48,3.6083018688055186,1.260998232460022,74.14000005859376,92.86400019775391 49,3.5781111992322483,1.1498024686813355,74.71200008789063,93.00600000976563 50,3.556312166727506,1.2316087501144408,74.506,93.2479999584961 51,3.5926481760465183,1.145808798084259,74.93599995117188,93.1839999609375 52,3.5503257513046265,1.2279409123039247,74.85000013183594,92.91600014160156 53,3.548908096093398,1.1786081212425232,75.20799995361328,93.3060000390625 54,3.5557837944764357,1.2336573403549194,74.51599998291016,93.11999993896484 55,3.4652767731593204,1.1627975540542603,75.40400003417969,93.35799999267579 56,3.479139502231891,1.138182176322937,75.28800010986328,93.4080001171875 57,3.5164922384115367,1.1406619202041626,75.22600006103515,93.27399998779296 58,3.501664170852074,1.1390633061408997,75.40600000732422,93.3199999633789 59,3.491689085960388,1.1517196659469604,75.43400013427734,93.35000009521484 60,3.5151750307816725,1.1642712952232361,75.45400010986329,93.34399998535156 61,3.4816090877239523,1.1658449925994874,75.48800003173828,93.6820001171875 62,3.4707261690726647,1.1238778091812134,75.67400008300781,93.52800009033203 63,3.4746356377234826,1.147294167804718,75.82200003417968,93.59600001220703 
64,3.459448658503019,1.174054825153351,75.43000003417968,93.5079999609375 65,3.4489993590575,1.1662658145713807,75.70599995605468,93.57999990966798 66,3.485427141189575,1.1435979705429078,75.92599992675781,93.64600016601563 67,3.4828745401822605,1.134217993068695,76.18200002929687,93.7379999560547 68,3.45894560447106,1.156855950126648,75.72199998046875,93.55200006591797 69,3.44607728261214,1.0945456811714172,76.3340001611328,93.77799993408203 70,3.477376791147085,1.1079738793563843,76.24000010986327,93.91200016601563 71,3.5141904812592726,1.0928245606613158,76.30200008789062,93.69600013916016 72,3.410055022973281,1.0860164308166504,76.26800003173828,93.90199995849609 73,3.420612326035133,1.116220949783325,76.65600010498046,93.8940001171875 74,3.4254535344930797,1.1075060968399049,76.16600008300782,93.7419999584961 75,3.4103054358409,1.1158017666244506,76.32199995361329,93.9019998803711 76,3.476789043499873,1.0348516142463684,76.68800013671876,94.00600001220702 77,3.4337857228059034,1.0966742949867248,76.42800005371093,93.84000006103516 78,3.4188507887033315,1.0247036343955993,76.71600001464844,94.09599998779296 79,3.423209318747887,1.0942145678901671,76.81400018554687,94.01599993408203 80,3.3477271520174465,1.0325566992378234,76.62600013427735,93.98600003662109 81,3.4215016089952908,1.046407821083069,76.63400003662109,94.1180000366211 82,3.430701503386864,1.0959824050521851,77.01999987304687,94.0459999609375 83,3.4161557325950036,1.0931310430908203,76.86399993164062,93.97000000732422 84,3.3875768918257494,1.108051809272766,76.73400020996094,94.07200014160156 85,3.4196390096957865,1.14773931640625,76.86000002929687,94.1320000366211 86,3.437829402776865,1.0369087854003907,77.00200010498047,94.10800000976562 87,3.385591616997352,1.0996613207626342,77.08599995361328,94.2460000366211 88,3.4244728546876173,1.062659384174347,77.17800010742188,94.22800000976562 89,3.388879519242507,1.0323741394615173,77.2419999243164,94.26999990966797 
90,3.378705382347107,1.1156270475387573,77.27599994628906,94.21800003662109 91,3.33666560283074,1.0247049710083007,77.25200018066407,94.21600001220703 92,3.386662767483638,1.0646427408790589,77.43999997802734,94.20000000976563 93,3.341835297071017,1.063674129047394,77.41000015625,94.33999998291016 94,3.3956225651961107,1.0314554983901978,77.34999997558593,94.32600013916016 95,3.338684366299556,1.1056486014938354,77.28800005371093,94.18600016357422 96,3.328054749048673,1.0338896635437012,77.83000005371093,94.4480000366211 97,3.3632744550704956,1.0532264839172363,77.81800002929687,94.45999998291016 98,3.3554288148880005,1.0673184670639038,77.83000002197265,94.48799992919922 99,3.3159962709133444,1.1167875634765625,77.59199997070313,94.43600013916016 100,3.3194725880256066,1.036779229888916,77.53200012939453,94.40800008789063 101,3.3196748128304114,0.9950653231811524,77.88199995605468,94.39600003662109 102,3.294141851938688,1.0245422155380248,78.03399997070312,94.53000011230469 103,3.3244126851742086,1.0557471356391908,77.7660000024414,94.47399998535157 104,3.3125311044546275,1.0209960681152344,78.218,94.63000006103516 105,3.3134930042120128,1.034874595451355,77.87000005371094,94.63400013916015 106,3.3416621409929714,1.022444501953125,77.88000011230469,94.48000003662109 107,3.2982225234691915,1.0652961749267578,77.80800005371094,94.52599993164063 108,3.2699141318981466,0.9846357432556152,78.00199992675782,94.64599993408203 109,3.334826551950895,1.0435322013473511,78.16800010253907,94.58799998291016 110,3.303573865156907,1.0032002863311769,78.34800002929687,94.69000000976563 111,3.3007822862038245,1.0601369972038268,78.02399995117187,94.53400008544922 112,3.2984123138281016,1.051457786140442,78.11000013183593,94.75999990478516 113,3.262780079474816,1.1014147911262513,78.36200005859375,94.7040000366211 114,3.28412912442134,1.0476678370285035,78.480000078125,94.75799998535156 115,3.240434399017921,1.0324708117103576,78.42200002929688,94.76400003417969 
116,3.2967760287798367,0.9791300971794128,78.43800000244141,94.78400003417968 117,3.261618274908799,1.051374360847473,78.38000015625,94.84999998535156 118,3.2742426487115712,1.0074605080032348,78.4679999975586,94.82000006103516 119,3.276140946608323,1.0688135836410522,78.35999984375,94.60799990478516 120,3.2416304624997654,0.996600732421875,78.74800010498046,94.91000008544921 121,3.2056774359482985,1.0518232615089416,78.67,94.78000000976563 122,3.227355828652015,0.9713950177574158,78.79799997558594,94.97000006103515 123,3.2498799654153676,1.0024585261154175,78.842,94.94400006103515 124,3.246049715922429,0.9771295314407349,78.96599997314453,95.06600011230469 125,3.261179988200848,1.0789605331230163,78.8920000756836,95.08599990478515 126,3.254664714519794,1.004692148284912,79.15199997558594,95.04799998291016 127,3.2356422589375424,0.9843215970993042,78.88800013427735,94.97000003662109 128,3.2246823494250956,1.0390974626731873,79.14800004882812,95.04000006103516 129,3.2226843650524435,1.0247631069946288,79.1079999194336,94.97200000976562 130,3.2466600766548743,0.9585929667854309,79.3819999194336,95.14600006103515 131,3.242083659538856,0.9486254569244384,79.40000020507813,95.22400008544922 132,3.2279451351899366,0.9887630248260498,79.13200005126953,95.02400000732422 133,3.220690608024597,1.0240418742561341,79.25200007324219,95.21199990478516 134,3.218092579108018,0.9779359574508667,79.25000002441406,95.21999998291015 135,3.2125297693105845,0.9315142623901367,79.20600002441407,95.17600013671876 136,3.239817573474004,1.0313529789352418,79.226000078125,95.06800013916016 137,3.18045646410722,0.9608411916351318,79.4799999243164,95.29999993164063 138,3.1873541336793165,0.9543602143478394,79.43999997558593,95.18199998291016 139,3.165460522358234,1.048680834980011,79.42599994873046,95.1680000366211 140,3.1896323424119215,1.0906442671966552,79.2520000805664,95.24000000976562 141,3.1748633751502404,0.9517472584152221,79.57200015136719,95.26599995605469 
142,3.1773726298258853,0.9701159319877625,79.45400012695312,95.21600000732421
143,3.159519048837515,1.0661436851692199,79.51800006835937,95.27800010986329
144,3.2158747819753795,0.9745243270874023,79.45800002685547,95.29599998291016
145,3.161927204865676,0.9890392238426209,79.53000012695313,95.23800000732422
146,3.1869288132740903,0.9883130007171631,79.784000078125,95.34600013916015
147,3.1633920211058397,0.9936535754013062,79.90000004882812,95.43400000732422
148,3.190845801280095,0.9889785444259643,79.71599997070312,95.39599990478516
149,3.1653112173080444,1.0283795692443847,79.9499999194336,95.36399992919922
150,3.166087508201599,0.9779253969573974,79.97,95.42999995605469
151,3.1479290081904483,1.0098531292724608,79.94000005371093,95.36200000732421
152,3.1536225997484646,0.9446434412956238,80.28599997070313,95.48000000732422
153,3.110550990471473,0.9584104975318909,79.865999921875,95.52199998046875
154,3.167463806959299,0.997737165107727,79.82200020751954,95.42799998291015
155,3.13789662031027,0.9710854097938537,80.03800002685547,95.44799998046875
156,3.1173960062173696,0.9749535838317871,80.34599997070312,95.50200000732421
157,3.1322424686872044,0.9212129807472229,80.262,95.50399992919922
158,3.1068271673642673,0.9729003804779053,80.38200004882812,95.52800018798828
159,3.1096662007845364,0.9109539186859131,80.596000078125,95.70000005859374
160,3.103513864370493,0.9322344965553284,80.41200009765625,95.57599998291016
161,3.132872122984666,0.882158300819397,80.38400020996093,95.57800003417968
162,3.1003327461389394,1.0276759369468689,80.59400004638672,95.57400000732422
163,3.107926056935237,0.9502820700454712,80.40000017578124,95.60600003417969
164,3.0697899781740627,0.9290867798805237,80.49800009765625,95.72800008544922
165,3.111760973930359,0.9023137127494812,80.51399999511719,95.69000000732422
166,3.0767463353964,0.9757205354499817,80.75000004638672,95.75400008544922
167,3.0687196346429677,0.8881613404273987,80.59600015625,95.73599998046875
168,3.1184717875260572,0.8646588974380494,80.87400010253906,95.79600013916016
169,3.090632071861854,0.9433945862197876,80.9039999975586,95.75199995605469
170,3.115909457206726,0.931336729221344,80.72200017578125,95.74200016357422
171,3.0933127036461463,0.9939709894943237,80.75599999755859,95.78200003417969
172,3.086327782044044,0.8836379608726501,80.86799991455078,95.96399992919922
173,3.077766941143916,0.9232083583641052,80.89999999511718,95.77400003662109
174,3.0573448768028846,0.9307572483062744,80.80800012695312,95.82199990478516
175,3.0762509199289174,0.8996614429855346,81.23000007568359,95.8099998803711
176,3.0396402340668898,0.9776417551994324,81.1700001220703,95.93000000488281
177,3.03109754048861,0.9624793661499024,81.14000007080078,95.93800010986328
178,3.040617612692026,0.9664762647247315,81.03800004882812,95.83199998291016
179,3.0303638715010424,0.894894326210022,81.37600004638672,96.00599992919922
180,3.009007288859441,0.9451599026298523,81.19999999511718,96.01800003417969
181,3.0288018721800585,0.9100659669685364,81.40999994140626,95.98600013671874
182,3.003125685911912,0.887927452545166,81.27000020019531,96.02800003173829
183,3.008813509574303,0.869329244632721,81.37200002685547,96.01800000732422
184,3.0261400571236243,0.9118497411537171,81.52200002197266,96.00800010986327
185,3.004841850354121,0.9412287125587463,81.54200007080078,95.99000010986327
186,3.013606933447031,0.8910019776153565,81.52200012451172,96.04800006103515
187,2.9876741079183726,0.923979015827179,81.59199994384765,96.07599998046875
188,2.9696753850350013,0.9355228743362427,81.60599996826171,96.01600008544922
189,2.954044461250305,0.9194323031997681,81.5740000756836,96.01999998291015
190,3.001693844795227,0.9340647240066529,81.67800007324219,96.07999998291015
191,2.929094571333665,0.8478024571037293,81.85800018066406,96.16800000976562
192,2.990995086156405,0.916919107837677,81.76200004882813,96.18799995605468
193,2.976861696976882,0.9573249016571045,81.62800002197265,96.06000005859374
194,2.9593937397003174,0.8564995549583435,81.91399994140625,96.27199992919923
195,2.941809947674091,0.8975905869483948,81.79000001953125,96.19000003173828
196,2.976395744543809,0.9395446335983276,81.95199991455078,96.15799998291016
197,2.9350195206128635,0.9216514646148681,81.84800004638672,96.24800013671874
198,2.942404765349168,0.833588249835968,82.15200007568359,96.24200003173829
199,2.9270170743648825,0.8581027887535095,82.06399997070312,96.26400005859375
200,2.9117765701734104,0.8230641770172119,81.91799999267577,96.30999992919922
201,2.9535570328052225,0.932054790763855,82.19600002197265,96.28800008544921
202,2.951778081747202,0.8631351574707031,82.04400009277344,96.26199992919922
203,2.8975991835960975,0.9194772821807862,82.2759999975586,96.3100000341797
204,2.9302108104412374,0.8795070449256897,82.24599999511719,96.35400000732422
205,2.88518776343419,0.8075243322563171,82.11200001953125,96.42400003173829
206,2.8956097731223474,0.9984648355102539,82.22799999267578,96.40400000488282
207,2.893759452379667,0.8589856274986267,82.30600004882812,96.39799992919922
208,2.9317749371895423,0.8957522449111939,82.28000004882813,96.35800003173829
209,2.8650061717400184,0.8618767699813843,82.32400001953125,96.39999992919923
210,2.88412383886484,0.8900039534378051,82.3660000439453,96.38799995605468
211,2.8546677277638364,0.8119068654060364,82.51199999755859,96.43000005615234
212,2.89072257738847,0.8770508949661255,82.59799999511719,96.38600005615234
213,2.890521159538856,0.8636206825447083,82.42000004638672,96.48800003173828
214,2.9142217819507303,0.8873085660552978,82.58000004394532,96.38800000732422
215,2.890397998002859,0.8559163021659851,82.7499999951172,96.45200003417969
216,2.858252213551448,0.8466270779418945,82.59599989257812,96.52399992919922
217,2.8843393417505117,0.9059183154869079,82.61600007080078,96.5360000830078
218,2.8276894367658176,0.8348352497673035,82.66399994140625,96.40600010986329
219,2.8219637045493493,0.8524979174041748,82.88600014892579,96.44799992919921
220,2.8488004757807803,0.877018903388977,82.80199997070312,96.53199995361328
221,2.8460423212784987,0.8800382892227173,82.90000007324218,96.49599998291016
222,2.8185073137283325,0.8374453451538086,82.75200012695312,96.53399998046875
223,2.8086545192278347,0.8324967349433899,82.94800004638672,96.54000008544922
224,2.8704542838610134,0.8967138567352295,82.99800004638672,96.60400000732422
225,2.806610419200017,0.8318240044593811,83.09800009521484,96.58800008544922
226,2.8394719453958364,0.87683580160141,83.16000001953125,96.53800003173828
227,2.8203150309049168,0.8570964310073853,83.18600002197266,96.55399992919922
228,2.834619778853196,0.7959771487045288,83.12600009521485,96.64399992919923
229,2.76179917042072,0.8651750437164306,83.08200004882812,96.58200003173827
230,2.79124588232774,0.8245219149398804,83.05600001953125,96.70399998046875
231,2.8034701347351074,0.8726015359115601,83.24400012207032,96.65399992919922
232,2.747335984156682,0.8783034913063049,83.432,96.67800008544921
233,2.7719013415850124,0.8303304001617432,83.30800009765625,96.63400003417969
234,2.7671942710876465,0.8046093546676636,83.38800022949219,96.72800000732421
235,2.7805707179583035,0.8224680534934997,83.34000001953125,96.63600005859375
236,2.7180897180850687,0.8442720851898193,83.31200017578125,96.66800010986329
237,2.7626740840765147,0.8542504468154907,83.55999991699218,96.74800010986328
238,2.7336452649189877,0.7994218709564209,83.48200022460938,96.75800000732421
239,2.7233040791291456,0.8043592370414734,83.36799997070312,96.75200008544923
240,2.6998801781580997,0.8452410479164123,83.50599999267578,96.73600010986328
241,2.73690388752864,0.8795205630493164,83.41199997070312,96.68199995605468
242,2.749963100139911,0.8602245480918884,83.54400004882812,96.77599995605469
243,2.698979405256418,0.8548383228492736,83.53200005126953,96.78000008544922
244,2.72799957715548,0.7891766788101197,83.61400001953125,96.73800008544922
245,2.7299381494522095,0.7910343458747864,83.68000006835938,96.82400010986328
246,2.701122834132268,0.8022129875183105,83.64600004638672,96.83600010986336
247,2.724263282922598,0.8257709519386291,83.74200012451172,96.81400003173827
248,2.700051261828496,0.8478059650039673,83.89000007324219,96.79600013671875
249,2.7121142332370463,0.8197229961776733,83.67799994140626,96.79600003417968
250,2.6754928001990685,0.807203971862793,83.82799999511718,96.82600000732423
251,2.6727003134213962,0.8816155452919007,83.89399997070312,96.84000021484376
252,2.678276859796964,0.7926013770103455,83.91199999267577,96.87600008544922
253,2.693314579816965,0.8830190794181824,83.71199999511718,96.84400008544922
254,2.7000538385831394,0.7949762279510498,83.96000001953125,96.87799992919922
255,2.6520780416635366,0.8772449176788331,83.88600017578125,96.84400008544922
256,2.6728187249257016,0.803320650062561,83.99600002197266,96.90200008544922
257,2.6415497614787173,0.8764836074638367,83.93999999511719,96.85999992919922
258,2.656065143071688,0.9068516694831849,84.12800015136719,96.85599998046875
259,2.6415039667716393,0.7998114633750916,84.04200009765626,96.88399998046874
260,2.6295432585936327,0.8035087095451355,84.12000022949219,96.86799992919921
261,2.6302762306653538,0.7987516967964172,84.03600001953124,96.87400000732421
262,2.6307426599355845,0.8718977772903442,84.01400010009766,96.88399998046874
263,2.646725838000958,0.8528344987106323,84.08799999267578,96.87399998046875
264,2.6446866530638475,0.8193294952964782,84.15799994140625,96.85400010986328
265,2.6704056446368876,0.8128056451225281,84.08600010009765,96.87400003173828
266,2.6277802265607395,0.8045233806419373,84.13200004638672,96.92999998046875
267,2.627194743890029,0.8470412642097473,84.22599999511719,96.84400008544922
268,2.6251884607168345,0.8471642654037476,84.24400010009765,96.92999990478516
269,2.6308794296704807,0.8462172267532349,84.2559999951172,96.91200005859375
270,2.602825742501479,0.8418390936470032,84.20400001953125,96.93599995361328
271,2.614965172914358,0.8266842465209961,84.29800010009765,96.98600003417968
272,2.6357730627059937,0.7932263405036927,84.19200004882812,96.99600000732421
273,2.6102960935005775,0.7815226528167725,84.29200010009765,96.94800003173827
274,2.581689577836257,0.8857900485992432,84.3839999951172,96.94000000732422
275,2.5989666902101956,0.8270471384239196,84.38600009765625,96.99800000732422
276,2.5846868386635413,0.7947231995010376,84.32600004882812,96.97799998046875
277,2.6171456300295315,0.8640070059013367,84.38799999511718,96.96999995605469
278,2.588435182204613,0.8077503689575195,84.31600012695313,96.95199992919922
279,2.625655614412748,0.8132058056640625,84.34200007324219,96.95599992919922
280,2.561418294906616,0.8367598581314087,84.33400010009765,96.99399998046874
281,2.5751408430246205,0.8749884683609008,84.37200007324219,96.97000003417969
282,2.5804577515675473,0.8383332873916626,84.42400009765625,96.94399998291016
283,2.6351821697675266,0.827440677433014,84.4620000732422,96.93799995361329
284,2.609296101790208,0.802014238834381,84.53200015136719,96.99599998046875
285,2.580980291733375,0.8811971377563477,84.35800017578126,96.90999998046875
286,2.612268191117507,0.7776840033721923,84.43000015136718,96.95199992919922
287,2.5652044919820933,0.8067414920043945,84.48200007324219,96.94799998046875
288,2.5902458887833815,0.86946218957901,84.34200015136719,96.94399992919922
289,2.5916363367667565,0.8006659157180787,84.45599999511718,96.94400005859374
290,2.5842091762102566,0.808074400806427,84.36799999511719,96.94000005859375
291,2.5799223551383386,0.7949452325820923,84.49000018066407,96.98000008544922
292,2.598302180950458,0.8045749850654602,84.56400007324218,96.95600000732422
293,2.5514029356149526,0.8555887586593628,84.42600004882813,96.92000008544922
294,2.570864484860347,0.7697095928192139,84.5159999975586,96.95799992919922
295,2.574796667465797,0.7760678444099426,84.45000004882813,96.97799998046875
296,2.5906164095951962,0.8296054859733581,84.56399999511719,96.93799995605468
297,2.5699487374379086,0.8084378022956848,84.59800007324219,96.95400003417969
298,2.589944940346938,0.8603264539909363,84.50600010009765,96.96199995605468
299,2.551648791019733,0.753419985370636,84.51800007324219,96.97800003173828
300,2.600609458409823,0.9259959189987182,84.56600012451172,96.93399992919922
301,2.557686319717994,0.7893291508483886,84.56200004882812,96.99399992919922
302,2.591786503791809,0.8664678469276428,84.50399997070312,96.98600003173829
303,2.5873728715456448,0.8108873423957824,84.55000010009766,96.95000003173828
304,2.571756307895367,0.8550376346588134,84.50399997070312,96.98800000732422
305,2.5680193167466383,0.8534311754226684,84.46400020507812,96.96399992919922
306,2.549737260891841,0.7815708198738098,84.44999999511718,96.97400003173829
307,2.5860844942239614,0.8156077337646485,84.49400007324219,96.96799990478516
308,2.5601435899734497,0.7796311644172669,84.52599999511719,96.99600003173828
309,2.5672685458109927,0.8047046170425415,84.56000004882813,97.00199992919921
================================================
FILE: checkpoint/iformer_large/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: true
aug_repeats: 3
aug_splits: 0
batch_size: 64
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 1.0
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 30.0
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.5
embed_dim: 384
epoch_repeats: 0.0
epochs: 300
eval_metric: top1
experiment: iformer_large
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.001
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-05
mixup: 0.8
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_large
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: 1.0e-08
output: checkpoint
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 5
warmup_lr: 1.0e-06
weight_decay: 0.05
worker_seeding: all
workers: 10
================================================
FILE: checkpoint/iformer_large/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,6.910546504534208,6.850864137573242,0.4499999999666214,1.8380000002288819
1,6.643188366523156,5.693355998840332,4.978000012207032,14.922000020751954
2,6.290052909117478,4.7761042728424075,12.611999971923828,30.79799999633789
3,5.945564820216252,3.82945441116333,23.45400005126953,47.374000043945315
4,5.6793032976297235,3.3477501653289794,30.48999998779297,57.0860001171875
5,5.384508353013259,3.0526050999450685,36.21600011962891,63.426000043945315
6,5.250409749838022,2.7507688496017457,41.93400006347656,68.9760001977539
7,5.041767652218159,2.462726919937134,46.78800009277344,73.27800002929688
8,5.04708271760207,2.4434100090026853,47.634000017089846,74.30800002441406
9,4.975198323910053,2.42360043762207,48.37999992431641,75.1100000048828
10,4.829011660355788,2.2661094271850586,50.68400001708984,76.53200008300782
11,4.770912170410156,2.0769049119567873,54.78399994873047,80.32799993896484
12,4.602100574053251,1.9912925200653075,56.97400007324219,81.9780001147461
13,4.475416091772226,1.903586951904297,58.506000017089846,83.21399985107422
14,4.387893658417922,1.738180983314514,61.67000002929687,85.12800008300782
15,4.2983089960538425,1.6956268508148193,62.94399990478516,85.9460001586914
16,4.295326141210703,1.634114217414856,63.83200000732422,86.87800005615235
17,4.199971575003404,1.5565200023078918,65.03799997558593,87.56400002441406
18,4.11814390696012,1.5510349703979491,65.20400005126953,88.04800012451172
19,4.085179814925561,1.5207680701065063,66.44000000732422,88.56600005126953
20,4.049802477543171,1.4801421740722656,66.96200008056641,88.81599999755859
21,4.023964524269104,1.5004107713890076,68.018,89.53000020507812
22,3.941151738166809,1.4183791262245178,68.2960000805664,89.44400018066406
23,3.92056103853079,1.4148408610725403,68.98600007324218,90.14000014892578
24,3.8803509015303392,1.362926469554901,69.38000001953125,90.02999994384766
25,3.885836592087379,1.3796905069923402,69.72199999267578,90.50999999267579
26,3.826949018698472,1.3466479450416564,70.56800001220704,90.72599996826172
27,3.807827426837041,1.294949890460968,70.44800017578125,91.05599999267578
28,3.7755139882747946,1.324147945137024,71.0500000805664,91.0999999975586
29,3.7466984620461097,1.290724811515808,71.66600008789062,91.4300000415039
30,3.7509118685355554,1.2814777055740356,71.91000001708984,91.6260000415039
31,3.737801056641799,1.2778163366699218,72.21599998535156,91.6360001196289
32,3.710351595511803,1.2349367953300476,72.29000001708984,91.89800001464843
33,3.7168827790480394,1.2245775287628173,72.55399996582031,92.18000006347657
34,3.6832274198532104,1.2634735249519349,72.36800007080078,92.01200004882813
35,3.6934918073507457,1.2468857565307616,72.46200011474609,92.12200006591797
36,3.6704082672412577,1.2496089472961425,72.92200000732421,92.40000008789063
37,3.61841528232281,1.1793549740600586,73.54600006835938,92.4979999633789
38,3.6186260260068455,1.176423654499054,73.8760000415039,92.71800011962891
39,3.6086827516555786,1.1873519218063355,73.72600017089843,92.56600004150391
40,3.6681673526763916,1.1816426961135864,73.89600013916015,92.75600006591797
41,3.5902654574467587,1.1312151807022095,73.84999998291016,92.69200009033203
42,3.584832870043241,1.1513134127235412,74.2380001953125,92.85600009033203
43,3.5984942454558153,1.1854713468933105,73.94800005859375,92.85000013916016
44,3.6091530047930203,1.1779234772491456,74.45200006103515,93.05000013916016
45,3.485910278100234,1.124913458557129,74.51200013916015,92.87999999023438
46,3.53954606789809,1.1186642870903014,74.91000016113281,93.23799998291015
47,3.5389673618169932,1.1638443058395387,74.89200000976562,93.29199998779296
48,3.581275013776926,1.1198174129295349,74.83599997558593,93.1000000366211
49,3.550371078344492,1.124014576892853,75.47600002929687,93.30400001220703
50,3.518071541419396,1.116720790939331,75.13999995117187,93.46999998535156
51,3.544528530194209,1.1221498290252685,75.35199995605468,93.39000006591797
52,3.5305112233528724,1.153511095676422,75.22399997558594,93.4659998803711
53,3.512775265253507,1.1154144276809692,75.54600013427735,93.59000001220703
54,3.5236593118080726,1.1202192064666747,75.3940000390625,93.45000014404297
55,3.433111392534696,1.096172775325775,75.87800006347656,93.67000001708985
56,3.4627847304710975,1.1273769705581664,75.95400010986329,93.6540000366211
57,3.4790163223560038,1.0841456203651427,76.21400003173828,93.7280001928711
58,3.469840278992286,1.1170489598274231,75.92600005615235,93.62200009033204
59,3.4525169225839467,1.0933158408927917,76.21000011230468,93.7460000415039
60,3.4901640873688917,1.126006812171936,76.19000008300782,93.78400006347657
61,3.4491439507557797,1.0970030860710145,76.21999998535156,93.81000001220703
62,3.432765245437622,1.102595337715149,76.3400001611328,93.96200000976563
63,3.4307719469070435,1.0557644449996948,76.3160000341797,93.7779999609375
64,3.4119530549416175,1.0703777256584168,76.16200011230468,93.74999993408203
65,3.4012008630312405,1.0802638423919677,76.45600000732422,93.7179998803711
66,3.4363037989689755,1.0589085584640503,76.77399997558594,94.11200008544922
67,3.442072950876676,1.079272590522766,76.80000013183594,94.03599990478516
68,3.4058668430034933,1.0758214369010926,76.81999989990234,94.10399990966796
69,3.420797137113718,1.0494617863845825,76.91000010986328,94.1500000366211
70,3.425681655223553,1.0563462660980225,76.71800008789063,94.0239998803711
71,3.4518049221772413,1.0740118409538268,76.88800006103516,94.09999998535156
72,3.3764912531926083,1.0612143629455566,77.09400005859375,94.22000000976563
73,3.396061952297504,1.0656421076011657,76.84600000488281,94.17600008789063
74,3.377353860781743,1.038946965942383,77.39399997802734,94.1700001928711
75,3.3845853255345273,1.0483537149047852,77.19200008544922,94.14399993408203
76,3.4312566060286303,1.0124181773376464,77.10600003417969,94.2560001147461
77,3.3898003651545596,1.0355482028388978,77.07400000488282,94.21800013916015
78,3.3775642835176907,1.0285578982925414,76.92400008789062,94.15399995849609
79,3.407664785018334,1.04491987159729,77.33000002685547,94.2539999584961
80,3.323443284401527,1.0443046270751952,77.40400016113281,94.26999998535156
81,3.3806505845143247,1.0459676809692382,77.52599987792969,94.2600000366211
82,3.382426720399123,1.042313983669281,77.59600008300781,94.29800001220703
83,3.3965557446846595,1.0520244059181214,77.6939999243164,94.46400005615234
84,3.3593188799344578,1.0567831263923646,77.55600005615234,94.34799995849609
85,3.3830844530692468,1.0939044120025634,77.37399982910156,94.3300001147461
86,3.3928306561249952,1.053267989654541,77.36200018554688,94.33200006347656
87,3.3483442343198337,1.03019209690094,77.79400000976563,94.4760000390625
88,3.400100533778851,1.0476388161849977,77.95800003173828,94.43800006103515
89,3.363451132407555,1.03560297996521,77.67800000488282,94.4600001171875
90,3.356158916766827,1.0289478774261474,77.82800010498048,94.31199995849609
91,3.3031399066631613,1.0243545026016236,78.07000008789062,94.37400001220703
92,3.340137399159945,1.0346087313461303,78.11400003173829,94.68599998291016
93,3.3096237182617188,1.0451400087165832,77.87400008300781,94.53799998291015
94,3.3592547178268433,1.0183623178482055,78.16999997314453,94.54400021728516
95,3.3121973734635572,1.034790303516388,78.08200010742188,94.58400011474609
96,3.285599561838003,1.0314845135879516,78.20800007324219,94.63800003417968
97,3.3367106272624087,1.027085668697357,78.33600010253906,94.68800010986328
98,3.3270413325383115,1.0407975508117675,78.17400010253907,94.70200016357421
99,3.2877965707045336,1.0243351782989503,78.14800010498047,94.69000021728516
100,3.2988253831863403,0.9814645049095154,78.23000010498046,94.48999995605469
101,3.2901770885174093,1.0053197279930115,78.35000014160157,94.68800014404297
102,3.2781461018782396,1.0209910627746581,78.3820000805664,94.74400003662109
103,3.303452546779926,1.0421659322166443,78.30199997802734,94.71599993408203
104,3.291901891048138,1.010951399459839,78.26399998046875,94.74600000976562
105,3.3149656699253964,1.0163331426811217,78.25600008544922,94.77600006103516
106,3.308741569519043,1.0146123032188417,78.522000078125,94.77400008544922
107,3.293911713820237,1.0555380408668518,78.21600012939453,94.7800000341797
108,3.2435123920440674,0.9968828507232667,78.66600003173828,94.75200008544923
109,3.3232491658284116,1.0174081901168823,78.465999921875,94.74800008544922
110,3.2690613820002627,1.0153419143104554,78.45600010742187,94.82400006103515
111,3.2709354712412906,0.9924267137527466,78.67800013183594,94.75800001220703
112,3.2855274860675516,1.0057064008140564,78.87199997314453,94.99600000732421
113,3.230813585794889,1.0080310935401917,78.80600005615234,94.86400006103516
114,3.2722094242389383,1.0076740469360352,78.86400005371094,95.06999998535156
115,3.22453341117272,0.9959960584068298,78.82800008056641,94.95199995605469
116,3.2814402763660135,0.9519607814407348,78.487999921875,94.87799998535156
117,3.2548026855175314,1.0326564323997498,78.628000078125,94.9060000366211
118,3.2524013702686014,1.0131952097511292,78.76800005126952,95.01200003417969
119,3.2414560868189883,1.0325893316078185,78.83600013183593,95.07600000976562
120,3.2432524240933933,0.9700666660690308,79.02600014892577,94.98200008544922
121,3.2066384187111487,1.0066341878509522,79.138000078125,94.99800003417968
122,3.212906388136057,0.9771206850433349,79.23599997558594,95.18600000732422
123,3.2288889243052554,0.9860764531326294,79.150000234375,95.0760000390625
124,3.227351968105023,0.987032350063324,79.35400015625,95.12600000732422
125,3.229693678709177,0.979343092956543,79.24400004882813,95.12000013916015
126,3.239646599842952,0.9877924637985229,79.15600005126953,95.2240001123047
127,3.2091639592097354,0.9348715629386902,79.27000005371094,95.20400013916016
128,3.2042169387523947,1.0053931137657166,79.28400007324218,95.2520000366211
129,3.2251157577221212,1.0165403331756593,79.19800013183594,95.1359999584961
130,3.2341381861613345,0.9704256143569946,79.396,95.2599999584961
131,3.2344076266655555,0.9606916565704345,79.44599995361328,95.26999990478515
132,3.2237201837392955,0.9550872252464294,79.938,95.31600000976563
133,3.2206397239978495,0.9822869607162475,79.31199984375,95.14399998291016
134,3.190979370704064,0.9883818350982666,79.44399992675781,95.17000013916015
135,3.2153615951538086,0.9423513364982605,79.59600007568359,95.31000003662109
136,3.2060717894480777,0.9656000201225281,79.69599997314454,95.30999998291016
137,3.169634094605079,0.9717180125236511,79.84799997558594,95.30599998291015
138,3.1719662959759054,0.9458955500030518,79.91200010498046,95.36600000732422
139,3.1517449158888597,0.9693083749580383,79.68799994628907,95.28599998291016
140,3.175395580438467,0.9482538063621521,79.91800015869141,95.36599995849609
141,3.169075736632714,0.9496471033668518,79.8199999194336,95.31599998291016
142,3.162912818101736,0.9770244847106934,79.74799994873047,95.29200003662109
143,3.149576177963844,0.9529659438323974,80.12799994140624,95.37199998291015
144,3.1874825679338894,0.9519318865776062,79.84200002441406,95.37400008789062
145,3.1350886271550107,0.9400943297576905,79.93600007324218,95.38200006103516
146,3.1708294336612406,0.9612622356414795,79.9659999243164,95.4799999609375
147,3.155010828605065,0.9531323616218567,80.27799997558594,95.48200008544922
148,3.172439987842853,0.9382277800369263,80.16600015625,95.53000011230469
149,3.1274655048663798,0.9533389562416077,80.13000002441406,95.55600008789062
150,3.1443814772825975,0.9371918731307983,80.33399997070312,95.56000006103515
151,3.1267175674438477,0.9546587055587769,80.41200012695313,95.64400006103516
152,3.153881384776189,0.9710652282524109,80.404000078125,95.69400000732422
153,3.0826464799734263,0.9507630409240723,80.43200005126953,95.59800003417969
154,3.15419864654541,0.9281492255783081,80.31200010009766,95.61800016357422
155,3.1137414620472836,0.9433984814453125,80.35400004882813,95.52200008544922
156,3.0899893045425415,0.9258719791603088,80.62000002197266,95.58799998291016
157,3.10144082399515,0.9700759565734863,80.70800010009765,95.55399992919922
158,3.0753163374387302,0.9230545413589477,80.6819999243164,95.64000016357421
159,3.0889107905901394,0.9027757112121582,81.10200007568359,95.81000018798828
160,3.0785955374057474,0.9065625233840943,80.93000007324218,95.75000000732422
161,3.107429366845351,0.9389520249366761,80.83800002197266,95.78800006103516
162,3.0843942990669837,0.9132129728317261,81.03600018066406,95.69800008544922
163,3.0945202020498424,0.9214211973762512,80.93000001953125,95.87799995361328
164,3.054575278208806,0.9130185977935791,80.78400002197266,95.78399992919923
165,3.0874763360390296,0.9020014848709107,80.94199999511719,95.86800013671875
166,3.0821682856633115,0.9138980267524719,81.06599997070313,95.79600006103516
167,3.040875664124122,0.9022171401786804,80.85800012695313,95.83800005859375
168,3.091547810114347,0.9082347438430786,80.89400007324218,95.83400016357422
169,3.100938613598163,0.9023391273498536,81.07800004882813,95.79400008544921
170,3.096641127879803,0.9022777452278137,81.05000004882812,95.85199992919922
171,3.0923791023401113,0.9108617130088806,81.129999921875,95.89400010986328
172,3.052710037965041,0.9479087080192566,81.01200005371093,95.84800010986328
173,3.0646523787425113,0.9275116222953796,81.29999997070313,95.96799998046875
174,3.0505178708296556,0.9056856367492676,81.38399997070313,95.94599992919922
175,3.06196405337407,0.9053907576179504,81.12999994628906,95.80200000732422
176,3.0117578506469727,0.8890290698051453,81.45199989257813,95.95200005859375
177,2.998016329912039,0.9136776152038574,81.152,95.88400008544922
178,3.009301029718839,0.9255991547012329,81.41199994140625,95.99000000732421
179,3.0024658258144674,0.8959691961097718,81.5320000756836,95.97800008544922
180,2.9945607552161584,0.9226432334518433,81.57999989257813,95.97600008544921
181,3.014344673890334,0.8720306507110596,81.6679999975586,95.93400000732422
182,2.9808755196057835,0.8806172090339661,81.3259999194336,96.07399995605469
183,2.9757775435080895,0.8842584958076477,81.55599999511719,96.05399990478516
184,3.0188873547774095,0.8955323510360718,81.74600004394532,95.99000000732421
185,2.9996942465121927,0.8864319196891784,81.68200012451172,96.08600006103515
186,2.9925501071489773,0.8620990559005738,81.90599994140625,96.17200008544921
187,2.9568693821246805,0.8663196732139588,81.82000007080079,96.12199992919922
188,2.9661437639823327,0.8848002459716797,81.72799994140625,96.10600000732421
189,2.9302183389663696,0.898970189743042,81.787999921875,96.08400003417968
190,2.982291790155264,0.898637077960968,81.8679999194336,96.07000003417969
191,2.9291364137942972,0.8914895185279846,81.88999997070313,96.12599998046875
192,2.9678718952032237,0.9045106686401367,82.06000002441407,96.13800005859375
193,2.9405993773387027,0.8911021675300598,82.00600007324219,96.13399998291015
194,2.942032126279978,0.8845843422508239,82.24200002441407,96.20000000976563
195,2.9096350211363573,0.8615652515411377,82.03000015136719,96.26600000732422
196,2.978213383601262,0.8571838988304138,82.2619999951172,96.31399995605469
197,2.915033533022954,0.8975017419433594,82.32400007324219,96.32999990478515
198,2.90700048666734,0.8796381487083436,82.37400012939453,96.22200000732421
199,2.891255121964675,0.8519094377326966,82.1440001513672,96.24399990478516
200,2.8835707902908325,0.8565283406829834,82.1980001538086,96.35399992919922
201,2.922385866825397,0.8883370819091797,82.22599999511719,96.34400010986329
202,2.921412935623756,0.8623068803215027,82.43399994140626,96.26800003417969
203,2.855884689551133,0.8547059134101868,82.58599997070313,96.36800000732421
204,2.8865085656826315,0.8690887733078003,82.58200009765625,96.42799998046875
205,2.8650464919897227,0.8593528522872925,82.65200002197265,96.35600000732421
206,2.8669342352793765,0.8630787907600402,82.83999988769531,96.36600003173828
207,2.8705076346030602,0.8752160062599182,82.48999994384765,96.35800000732422
208,2.901468093578632,0.8615425992584228,82.84400007324219,96.39800000732421
209,2.8473127896969137,0.8606333646774292,82.68400020019531,96.49200000732422
210,2.8600355203335104,0.8499924317932129,82.50199997070312,96.38400005859376
211,2.835801885678218,0.8610882561302186,82.69799999511719,96.39600003173828
212,2.851739296546349,0.8339872376823425,82.96800017578126,96.49599990478515
213,2.8614711669775157,0.8409391230201722,82.8040001220703,96.54600005615234
214,2.8752744381244364,0.8455844373703003,82.92400012207031,96.59000003173828
215,2.8528220103337216,0.830362998752594,83.00600001953126,96.54000000732422
216,2.8457366869999814,0.8478204371833802,82.98000012695313,96.51600010986328
217,2.8469672203063965,0.8468237362861634,83.06399986328125,96.54400010986328
218,2.798969342158391,0.8438536651229859,83.34799994140624,96.55800005859375
219,2.7844803608380833,0.8434881903266906,83.21400012695312,96.59400006103516
220,2.8229461449843187,0.851990555973053,83.23599989257812,96.62600008544922
221,2.8255511980790358,0.8601199535369873,83.14200005126953,96.58200005859375
222,2.769902761165912,0.8383212358665466,83.32399994628906,96.50600018798828
223,2.788533880160405,0.8396528512954712,83.27799996826172,96.50800000732421
224,2.830211648574242,0.8376853758049011,83.44600001953125,96.66000003173828
225,2.7834478525015025,0.8274393226051331,83.33000004394532,96.56200008544921
226,2.8031968611937303,0.83849549451828,83.34400001708984,96.56799992919922
227,2.7822887163895826,0.8502991510391236,83.36599996826172,96.65399990478515
228,2.7865052681702833,0.8371617828369141,83.30399996582031,96.65000018798828
229,2.736697848026569,0.8339061546897888,83.46600001708984,96.67600003173828
230,2.756590568102323,0.8257232703781128,83.60399996826172,96.74800003173829
231,2.754254698753357,0.8320276457595825,83.5720000732422,96.66800003173829
232,2.717001960827754,0.8391983484077453,83.53200017578125,96.62000000732422
233,2.7290770640740027,0.8272213029289246,83.58199996826171,96.76800008544922
234,2.744955145395719,0.8214240999031067,83.49600012695312,96.77000000732421
235,2.7437443091319156,0.8427003208732605,83.48800015136719,96.67799995605469
236,2.6601342146213236,0.8298304134178162,83.58400004882813,96.70799995605469
237,2.7236390939125648,0.8376305979537964,83.87200014892578,96.78199992919922
238,2.6828131584020762,0.8278305959701538,83.7240000415039,96.82199998046875
239,2.686338681441087,0.8361137383460998,83.79000009765625,96.7519999584961
240,2.6599515951596775,0.8282428256988525,83.8899999951172,96.84599998291016
241,2.708692715718196,0.840486863193512,83.86400007324218,96.81600008544922
242,2.706350097289452,0.8143806471061706,83.84199999511719,96.72200008544922
243,2.675070423346299,0.820715518321991,83.93799999511718,96.78400005859375
244,2.67621019253364,0.8190324502944947,83.76800007324219,96.73399990478515
245,2.6871630870378933,0.8306452424812317,83.95400001953125,96.75600003173828
246,2.668807891698984,0.8138127730178834,83.83600001953126,96.82799990478516
247,2.670647089297955,0.8158876292228698,83.91600014892578,96.84599992919922
248,2.6516282008244443,0.8102354573631286,83.95800007080078,96.90400000732421
249,2.6651534667381873,0.8250773019981384,84.03800012207032,96.84999998046875
250,2.634173879256615,0.8186645620155334,84.04400009765625,96.84399998291016
251,2.6318732500076294,0.8051316935539246,84.12400007324219,96.82799995605468
252,2.6179944460208597,0.7980528848648071,84.06599999267578,96.82800003173828
253,2.647362910784208,0.834535809059143,84.21600001708984,96.90199992919922
254,2.6695059812985935,0.7965805632972718,84.12399994140625,96.86800000732421
255,2.6050742131013136,0.8134198690032959,84.21600009765625,96.84599992919922
256,2.628141458217914,0.8186293986320495,84.19600007080078,96.83400003173828
257,2.5968537147228536,0.8014547142791748,84.20799999267578,96.88799992919922
258,2.595403607075031,0.8086359707069397,84.36800004638671,96.89000003173828
259,2.5980473114893985,0.8042134921073913,84.22400004638672,96.88799998046875
260,2.579132290986868,0.8025791395187378,84.32800009765624,96.89399998046875
261,2.578222779127268,0.8010256945419312,84.27399991699218,96.87000000732422
262,2.583704893405621,0.8182663324546814,84.27199994140625,96.88799992919922
263,2.5933315478838406,0.8125042825698853,84.42200009765625,96.96999992919922
264,2.594082392179049,0.8039098625183105,84.48800001953126,96.92399998046875
265,2.590479172193087,0.7950776927566529,84.36800004394532,96.94000003173828
266,2.5677045033528256,0.7946233463287353,84.27600012451173,96.92599998046875
267,2.5831249952316284,0.7977374027442932,84.35800007324218,97.02199998046875
268,2.574748296004075,0.8000576438903808,84.36599989257813,96.96199998046875
269,2.572749523016123,0.8113612873840332,84.43400009765625,96.96200008544922
270,2.5436680500323954,0.8111577160453797,84.49600012207031,96.93799998046875
271,2.5695200791725745,0.8020232275581359,84.50399999511718,96.95399990234375
272,2.5755814955784726,0.8084232340240478,84.37200012207032,96.98400005859375
273,2.5361680342600894,0.8032256897354126,84.48599991455077,96.98399992919921
274,2.5335547832342296,0.8067118916130066,84.49200004394531,96.95800003173828
275,2.5463146613194394,0.8099793620681762,84.49399999267578,96.92399992919921
276,2.5408946642508874,0.8107297235488892,84.56000012451172,96.97399992919922
277,2.5601063875051646,0.79685419921875,84.54400001953125,96.94800000732423
278,2.5401773361059337,0.8168009343910217,84.58600015136719,96.91799992919921
279,2.5693678764196544,0.8146619980621338,84.57600012207031,97.00399992919922
280,2.4974713142101583,0.8080339497184753,84.60600020019531,96.95999992919921
281,2.5239082941642175,0.8044849347305297,84.65400004394532,96.98600008544922
282,2.5418753899060764,0.7914894469451904,84.64600007324219,96.99399998046874
283,2.5818122075154233,0.8061416247177124,84.64800012207031,96.94399998046875
284,2.5547984655086813,0.7970900491142273,84.6160000439453,96.95999992919921
285,2.5307810673346887,0.8120660132408142,84.62400004394532,96.95200008544921
286,2.568904161453247,0.793556639137268,84.68400009765625,96.94599998046876
287,2.5109933614730835,0.7803754375076294,84.68400004638671,96.99200005859375 288,2.5294809433130117,0.8061936453437805,84.63400014892578,97.01599992919922 289,2.5508905740884633,0.8128651645469666,84.67599996582031,96.97600000732422 290,2.5313611947573147,0.7819976863479614,84.75200001953125,97.03600008544922 291,2.5070205376698422,0.8001993773460389,84.75200009765625,97.01799992919922 292,2.5351219268945546,0.8031717999839783,84.68800009765626,96.98799992919922 293,2.480671396622291,0.8163573293685913,84.67800007080078,96.96999992919922 294,2.4951118505918064,0.7894428477478027,84.67400001953125,96.97799992919921 295,2.517219598476703,0.799906842327118,84.6820000732422,97.00399998046875 296,2.517501244178185,0.7853897292900085,84.69600001953125,96.98999992919921 297,2.4902893488223734,0.8003411549377442,84.72799996582032,97.00999998046875 298,2.5424841458980856,0.7879449011230468,84.69800012207031,97.01599998046875 299,2.490955499502329,0.7997499727249145,84.62600017578124,97.00399992919922 300,2.5226229337545543,0.798079895362854,84.67600009765626,96.98999992919921 301,2.4854222994584303,0.8056669466209412,84.73400004394531,96.95399992919921 302,2.51904087800246,0.8224670502853394,84.64000007080078,96.97400000732422 303,2.5248863238554735,0.7937660120773316,84.72000004394532,96.96400010986328 304,2.4830260826991153,0.8070458587837219,84.63000001708984,96.97000010986328 305,2.520942073601943,0.809012079486847,84.58200012207031,96.94400010986328 306,2.4892017474541297,0.795564078617096,84.64000017333984,96.96200010986328 307,2.529263872366685,0.8014956588554383,84.71000012207031,96.97800000732421 308,2.510762306360098,0.7898429228019714,84.71800009765624,96.95399992919921 309,2.516177461697505,0.7976002066612243,84.6880001220703,96.99199998046875 ================================================ FILE: checkpoint/iformer_small/args.yaml ================================================ aa: rand-m9-mstd0.5-inc1 amp: false apex_amp: false aug_repeats: 0 aug_splits: 0 
batch_size: 64 bce_loss: false bce_target_thresh: null bn_eps: null bn_momentum: null bn_tf: false channels_last: false checkpoint_hist: 1 class_map: '' clip_grad: null clip_mode: norm color_jitter: 0.4 cooldown_epochs: 10 crop_pct: null cutmix: 1.0 cutmix_minmax: null data_dir: /dataset/imagenet-raw dataset: '' dataset_download: false decay_epochs: 100 decay_rate: 0.1 dist_bn: reduce drop: 0.0 drop_block: null drop_connect: null drop_path: 0.2 embed_dim: 384 epoch_repeats: 0.0 epochs: 300 eval_metric: top1 experiment: iformer_small gp: null hflip: 0.5 img_size: 224 initial_checkpoint: '' input_size: null interpolation: '' jsd_loss: false local_rank: 0 log_interval: 50 log_wandb: false lr: 0.001 lr_cycle_decay: 0.5 lr_cycle_limit: 1 lr_cycle_mul: 1.0 lr_k_decay: 1.0 lr_noise: null lr_noise_pct: 0.67 lr_noise_std: 1.0 mean: null min_lr: 1.0e-06 mixup: 0.8 mixup_mode: batch mixup_off_epoch: 0 mixup_prob: 1.0 mixup_switch_prob: 0.5 model: iformer_small model_ema: false model_ema_decay: 0.9998 model_ema_force_cpu: false momentum: 0.9 native_amp: false no_aug: false no_ddp_bb: false no_prefetcher: true no_resume_opt: false num_classes: null opt: adamw opt_betas: null opt_eps: null output: checkpoint patience_epochs: 10 pin_mem: false port: '25500' pretrained: false ratio: - 0.75 - 1.3333333333333333 recount: 1 recovery_interval: 0 remode: pixel reprob: 0.25 resplit: false resume: '' save_images: false scale: - 0.08 - 1.0 sched: cosine seed: 42 smoothing: 0.1 split_bn: false start_epoch: null std: null sync_bn: false torchscript: false train_interpolation: random train_split: train tta: 0 use_multi_epochs_loader: false val_split: validation validation_batch_size: null vflip: 0.0 warmup_epochs: 5 warmup_lr: 1.0e-06 weight_decay: 0.05 worker_seeding: all workers: 10 ================================================ FILE: checkpoint/iformer_small/summary.csv ================================================ epoch,train_loss,eval_loss,eval_top1,eval_top5 
0,6.9094244333413934,6.864997950286865,0.3860000048828125,1.5220000048828124 1,6.569643974304199,5.442198684692383,6.860000015258789,19.424000006103515 2,6.168546676635742,4.498686260910034,16.655999987792967,37.183999958496095 3,5.779889510228084,3.6264168195343016,27.648000013427733,52.987999997558596 4,5.510022530188928,3.047014978942871,35.92400004272461,62.721999880371094 5,5.173868894577026,2.683674906616211,43.09600005615234,70.35399994384765 6,4.986480364432702,2.4179258517074587,49.2040000390625,75.45399994873047 7,4.783603338094858,2.1778022815322875,53.34999995605469,79.12000002197266 8,4.67784386414748,2.046561869277954,56.38800017578125,81.42400006103516 9,4.566337640468891,1.971249479675293,58.90599991210937,83.36400008544922 10,4.354064602118272,1.9051295560455321,60.89799993652344,84.71400011230469 11,4.3395165846898,1.8117209907531737,61.95200006591797,85.39400020751953 12,4.263674002427321,1.763902481956482,63.86200010986328,86.63800010742187 13,4.20427758877094,1.7423424057769776,64.12000009033203,87.0420000830078 14,4.176988821763259,1.55658791595459,65.61000010742187,87.72 15,4.075015709950374,1.509385205974579,66.51999999023438,88.22600004394532 16,4.076420655617347,1.554206342601776,66.93600020263672,88.6279999951172 17,4.026540004290068,1.5549724096870423,67.67199991455078,89.07600001953125 18,3.9580479585207424,1.5148561889839172,67.98800004150391,89.18000014892579 19,3.957311511039734,1.4817134371757508,68.57399999755859,89.70000004638672 20,3.940097460379967,1.4594665077018738,69.18200001708985,89.91600001708984 21,3.9163199479763326,1.3787175881767273,69.98600007080078,90.29000004394531 22,3.838025588255662,1.378886430244446,69.57000009277344,90.28199999511719 23,3.810691558397733,1.3630643543052674,70.48800014648438,90.51600001953125 24,3.808310343669011,1.3218276855659485,70.53600006591797,90.73799999267578 25,3.835831018594595,1.361201537361145,70.79800001220703,90.82400004394532 
26,3.7596489832951474,1.3758483277130127,71.0260000366211,91.13600012207031 27,3.7489256125230055,1.384459337501526,71.02999996337891,91.05399994140625 28,3.7584116642291727,1.3283219172668457,71.97800001220703,91.45600009521485 29,3.7523052509014425,1.3312880218887329,71.82800014160156,91.43600006835938 30,3.7243771736438456,1.296002097415924,71.90199990722657,91.47000001464843 31,3.7210132158719578,1.3768975380706787,72.24400001464844,91.61199999267578 32,3.6862072394444394,1.2716120315933228,72.53200000976562,91.7980001171875 33,3.714735608834487,1.2190303552818298,72.40000003417968,91.77000009033203 34,3.69110553081219,1.2783337624168396,72.73400003417969,91.73400006591797 35,3.6636599118892965,1.313292693901062,72.52200013671875,91.90200009033204 36,3.649338933137747,1.215225392036438,72.95600008544922,92.00399999023438 37,3.640485250032865,1.2873099118232727,73.19200000976562,92.10800004394531 38,3.649407689387982,1.2403657006073,73.21400006347656,92.20200014404297 39,3.606833659685575,1.235068807926178,73.37400013671875,92.18999996337891 40,3.7104046161358175,1.2011462058448792,73.44600001220704,92.3440000415039 41,3.6258209668673,1.3457986955070496,73.85400016601562,92.31399993896484 42,3.599777863575862,1.2382625227165223,73.50000000732422,92.20800001464843 43,3.6305395456460805,1.2349728396415711,73.44799985351563,92.32800001464844 44,3.6406014424103956,1.2155342070961,73.8500000415039,92.33399993896484 45,3.5644080822284403,1.2261458265686036,73.79799990722657,92.51599986083984 46,3.581205817369314,1.2117936054229737,74.01800008789063,92.4740000415039 47,3.5979900726905236,1.2034554265403747,74.19799993164062,92.5880000390625 48,3.627637432171748,1.264913904914856,74.18600008544922,92.46200009277344 49,3.580585874043978,1.213218408164978,74.36400011474609,92.62800009521484 50,3.5658984092565684,1.2477872757339477,74.36799991210937,92.71800006591796 51,3.5925065829203677,1.1601819841575622,74.7800000341797,92.90200011962891 
52,3.535925791813777,1.2709670811653138,74.69200011230468,92.8500001977539 53,3.57413322191972,1.2347692168426514,74.53600013183593,92.70200006591797 54,3.569970341829153,1.2118336687660218,74.42800001220704,92.91000017333984 55,3.4951596718568068,1.232526140346527,74.52799998046875,92.87799993896485 56,3.528607414318965,1.2083224132537842,74.58200013671875,92.87600006835937 57,3.5129952339025645,1.2047773654174805,74.81600010986328,92.98000016845702 58,3.5265593253649197,1.1868788903999328,75.10600000488282,93.03999998779297 59,3.527732014656067,1.1377841114616394,75.05200003417968,93.12600011962891 60,3.5698564511079054,1.2140665634918213,75.15199995849609,93.22600009033204 61,3.5069606395868154,1.212316800441742,75.02599998046875,93.09600006347657 62,3.5076429568804226,1.207794077129364,75.30800006103516,93.16800016845703 63,3.5095110673170824,1.1751184555625915,75.05800008300781,92.98200006591797 64,3.495274332853464,1.1663758554077148,75.43800009033203,93.19200001220703 65,3.496675436313336,1.156433831691742,75.37000001220703,93.27800001220703 66,3.513992263720586,1.1885098503875733,75.70400010986329,93.41200008789062 67,3.504016555272616,1.1901980477905274,75.6640000341797,93.48599990966797 68,3.5192229656072764,1.2184467774772645,75.40999995361328,93.28399993408203 69,3.4537463738368106,1.144714741153717,75.72999997802734,93.42600001220703 70,3.50569604910337,1.1589157112121582,75.66200005371094,93.31400006591797 71,3.5316452154746423,1.0998509642791747,76.00400008544922,93.43200011962891 72,3.4660944846960215,1.125511461544037,76.02200003173829,93.37199996337891 73,3.455447737987225,1.2478693104553222,75.81600005615235,93.34800016845703 74,3.474152766741239,1.2159589889144897,75.98600014404298,93.4120000390625 75,3.4426536560058594,1.135044396018982,75.85400003173828,93.53000008789063 76,3.5005324253669152,1.1231057829284667,75.93800005371094,93.47800016601562 77,3.4688816253955546,1.1647394517326355,76.06999992919921,93.5140000390625 
78,3.4501122419650736,1.0868774487495423,76.11599992675781,93.48000001220703 79,3.4457626984669614,1.1312008828926086,76.07600013671875,93.55400003662109 80,3.3960279593100915,1.0933656254005433,76.20400010986329,93.6460000390625 81,3.4835668068665724,1.1270457012176514,76.2800000830078,93.7440001147461 82,3.4436028645588803,1.126568383769989,76.30000010742188,93.69600006591797 83,3.4565229690991917,1.1777190805244446,76.37800005371093,93.68800011474609 84,3.414357983149015,1.188438509941101,76.37800010986328,93.59000019287109 85,3.459013196138235,1.2144730036735534,76.43800008544922,93.6420000366211 86,3.45191019314986,1.1025224459266663,76.35400005859375,93.5240000390625 87,3.412708933536823,1.1987829148101807,76.59799992675781,93.79000006347657 88,3.466808089843163,1.1002231981658936,76.82800010986328,93.8160000390625 89,3.4394527673721313,1.0912003845787048,76.66200010986329,93.84000009033203 90,3.434341018016522,1.1553971852874756,76.66199998291016,93.88400003662109 91,3.394513542835529,1.0940003709220887,76.6640000830078,93.77599998291015 92,3.434716270520137,1.107185283679962,76.77800005859375,94.02400000976563 93,3.3863757940439076,1.1242798097229003,76.96600002929688,93.85199998291016 94,3.430507412323585,1.1021452717971802,76.88600005371094,93.89400003173829 95,3.381525305601267,1.1088750526237487,77.00400008544922,93.92400016845703 96,3.387681630941538,1.1086897051239013,76.98400005371094,93.8740001147461 97,3.407272118788499,1.0778170557975768,77.01400020996094,94.00999993408203 98,3.4062380607311544,1.0873821159362793,76.81799995361328,93.92000001220703 99,3.3701010025464573,1.1149417749214172,76.86600010986328,94.01000019287109 100,3.3691365076945377,1.1462936076545716,76.86200003173828,93.82600013916016 101,3.375997204046983,1.0753991451454163,76.94400018554687,94.07800009033203 102,3.357974116618817,1.135077677707672,77.24800008300781,94.0820001171875 103,3.3830726513495812,1.1030329714012146,77.17000008056641,94.06600008789063 
104,3.367032500413748,1.082635837879181,77.20999989746093,94.11800001220703 105,3.3800272391392636,1.0921122213745118,77.23400008300781,93.99600001220703 106,3.3662016024956336,1.1072591597175598,77.01600000732422,94.1160001147461 107,3.3502725729575524,1.1785977474021911,77.32,94.09800008789063 108,3.340722441673279,1.0760896536827087,77.47400005859375,94.14800014404297 109,3.3812398176926832,1.083845131263733,77.57799998046875,94.2759999584961 110,3.3553919150279117,1.1084872850608827,77.42600009033202,94.15200001464844 111,3.325396336041964,1.1166858360671996,77.54600012939453,94.14599993164063 112,3.349822860497695,1.085412045841217,77.49600010253906,94.20199998291015 113,3.3029517852343044,1.1392902358436585,77.62800005859376,94.40200006103515 114,3.3522927577678976,1.142200709590912,77.652000078125,94.34400016601562 115,3.316668547116793,1.074139283103943,77.68600005371094,94.45799995849609 116,3.326545972090501,1.07850544713974,77.77000008300782,94.21599998291016 117,3.315403158848102,1.145295484161377,77.74600005615234,94.41199998291016 118,3.315896850365859,1.0736533163452149,77.73600005126953,94.33600000732422 119,3.332959147600027,1.0669116265106202,77.61800013916016,94.34399993408203 120,3.342320543069106,1.0954551971435547,78.09600010986328,94.49600006591797 121,3.269335691745465,1.109053556213379,78.05999987304688,94.4940001147461 122,3.299822678932777,1.059656356754303,78.30600013427734,94.56000011474609 123,3.3047666274584255,1.120321944065094,77.90000005371094,94.41400000976563 124,3.318658021780161,1.0599626556968689,78.116000078125,94.53599998291016 125,3.3260184893241296,1.104176430568695,78.12400000732421,94.45600003662109 126,3.327208610681387,1.080815853767395,77.98200008300782,94.48600008789063 127,3.2752062265689554,1.0735882032585145,77.95600008300781,94.48600006591796 128,3.2862806778687697,1.1019770643424989,78.19400015625,94.37000008789063 129,3.300983373935406,1.1177044029426575,77.99000013183594,94.31200008789062 
130,3.2768982557150035,1.0396747249031066,78.43000000488281,94.7400000366211 131,3.2978845559633694,1.0482300101852418,78.12200002441406,94.59199998291015 132,3.328818284548246,1.0577856283950806,78.08999995361329,94.65000001220703 133,3.2839138141045203,1.0800661448669433,78.22400012939453,94.58000006103515 134,3.2449550261864295,1.0812386049079894,78.42199994628906,94.67800006103515 135,3.273659210938674,0.9925222877883911,78.83200018066407,94.73800013916015 136,3.279651779394883,1.1514595987319947,78.44199994628906,94.62399998291015 137,3.249285734616793,1.0616181856918334,78.67999995117188,94.70000003662109 138,3.2401852149229784,1.030422007408142,78.55200002441406,94.73399990478515 139,3.2287596739255466,1.11498403093338,78.648,94.83400006103516 140,3.2589325996545644,1.1934773761367798,78.630000078125,94.80000013916016 141,3.235685348510742,1.0107798411178588,78.718,94.85000006103516 142,3.255448185480558,1.0406203583145142,78.82800008300781,94.84600003417968 143,3.2458052726892324,1.1242790790367125,78.75999992675781,94.8079999584961 144,3.243349625514104,1.0674845689964294,78.73600005371094,94.85800008544922 145,3.208817793772771,1.053589370918274,78.48999994873047,94.74800000976562 146,3.247132796507615,1.0643479359817505,79.04200002441407,94.94800005859375 147,3.2469917077284594,1.0336022610664368,79.01399994873047,94.91199998291016 148,3.246920136305002,1.0803122089195252,78.92200005371093,94.84000011474609 149,3.2142061361899743,1.0599778313064576,78.99600010498047,94.92599995849609 150,3.217764240044814,1.06389358127594,79.10400002929687,94.95799995849609 151,3.216148073856647,1.0799826690673828,79.1300000024414,94.83200016357422 152,3.202606833898104,1.000363338546753,79.16200004882812,94.87800006103515 153,3.1680813752687893,1.026247396221161,79.16999989746094,94.98600003417968 154,3.2302153752400327,1.05421796377182,79.19600005371093,94.90400006103516 155,3.2140613060731154,1.034919361228943,79.1000000024414,94.9480001147461 
156,3.212250150167025,1.0543945582008363,79.29199989746094,94.99999990478516 157,3.189672048275287,0.9840240316200256,79.35799997558594,95.16799995605469 158,3.1594845056533813,1.0651922596740722,79.29200010498047,94.97399990478516 159,3.169252661558298,1.0438192999267577,79.34,95.15399995605469 160,3.1778015265097985,0.9776468455314636,79.22999994628906,95.04400008544921 161,3.184331774711609,0.9410803900527954,79.59000010253907,95.23599998291016 162,3.172625651726356,1.1021912324523926,79.44200005126953,95.08000006103515 163,3.172793911053584,1.006711642932892,79.650000078125,95.21000003417969 164,3.1505272480157704,1.0151854071998596,79.25200004882812,95.04000016357422 165,3.2113893582270694,1.002980606880188,79.53400010742187,95.18599995605469 166,3.203436310474689,1.069907742919922,79.4719999243164,95.11200008544922 167,3.1434664726257324,0.9765472800827026,79.68800013183593,95.34800003173828 168,3.2033794017938466,0.954885217037201,79.74200010253907,95.34199995605469 169,3.157770037651062,1.015079356956482,79.69200018066407,95.14799992919922 170,3.1801951298346887,1.0224423956489563,79.84799997558594,95.30400006103515 171,3.183558913377615,1.0809200729370116,79.80000002929688,95.23600000732422 172,3.166443494650034,0.970890617313385,79.75400015625,95.26800005859376 173,3.1588093500870924,0.9810822336959839,79.93399997558593,95.40400006103516 174,3.1534451246261597,0.9683314956474304,80.04599997314453,95.36200006103516 175,3.176486510496873,0.9815351949119567,79.97000010498047,95.32399995849609 176,3.133654860349802,1.033868189277649,80.04200004882813,95.46000008300781 177,3.0971615681281457,1.057138032245636,79.988,95.32999988037109 178,3.0857103237738976,1.0150846409988403,80.01399987060547,95.31000024169921 179,3.1099898723455577,0.9767458859634399,80.17400010253907,95.46799990478516 180,3.0863860020270715,0.9960832705497742,80.35000010253906,95.4619998803711 181,3.1192220632846537,0.995442993927002,80.112,95.46599992919921 
182,3.1065458059310913,0.9503561297225952,80.30000013183594,95.50400008789063 183,3.084474407709562,0.9327744506835938,80.504,95.55199990478516 184,3.1169589299422045,0.987274864768982,80.31800020507812,95.52000008544921 185,3.105303324185885,0.9985076886749268,80.57799994628907,95.55799992919921 186,3.086181319676913,0.9722424831962585,80.55400010498047,95.54000000732422 187,3.0753506880540113,1.005631930141449,80.45999997070312,95.51000008544922 188,3.084871640572181,1.0012442365074157,80.48399994628906,95.65999995849609 189,3.0642203917870154,0.9958682190895081,80.54799996582031,95.58800003173828 190,3.115300545325646,0.9633640588378907,80.41400002685548,95.60000003417969 191,3.0352935699316173,0.9269049572944641,80.61600001953126,95.64400008544922 192,3.096716523170471,0.9777986540985107,80.72600010742188,95.65000000976562 193,3.0447621712317834,1.020225089454651,80.82200004882813,95.61000000732422 194,3.0424715830729556,0.9350263250541687,80.91199994628906,95.66599992919922 195,3.04462403517503,0.9920734120368957,80.93400012695312,95.65400008544921 196,3.0816985093630276,0.9962930574035644,80.73600002441407,95.57600008544922 197,3.043275053684528,1.0144961289787293,80.85000002441406,95.69400008544922 198,3.0301405558219323,0.921901376285553,80.81000007080078,95.72999992919922 199,3.021504677259005,0.9390514798736572,80.93599991699219,95.72799998046875 200,3.0032437581282396,0.9052555331039429,80.92400002441406,95.72999990478516 201,3.0524399005449734,0.9862956017303467,80.98800002441406,95.70599992919922 202,3.04898725106166,0.9408524942779541,80.8440000756836,95.75799992919922 203,2.9965205651063185,0.9691278018569947,81.11999997070312,95.80000010986328 204,3.0177720876840444,0.9536201685333252,81.08200010253907,95.85800008544922 205,3.0006031714952908,0.8881150724983216,81.10200009765624,95.74999992919922 206,3.0060267448425293,1.0253046160316468,81.081999921875,95.70000010986328 207,2.9632287667347836,0.9307623635292053,81.20199997070313,95.88400010986328 
208,3.014235395651597,0.980958409614563,81.24000006835938,95.85200008300781 209,2.9804076965038595,0.9197839248657227,81.26800004638672,95.75000016357421 210,3.0220154248751125,0.9738036029434204,81.32600002441406,95.93999995605469 211,2.9515054042522726,0.8529443983078003,81.53800001953125,95.91200010986329 212,3.0006972001149106,0.9346630925750733,81.41200010253907,95.92800008544921 213,2.993186171238239,0.9045049831581116,81.31800002197265,95.89400005859375 214,3.0099153335277853,0.9540823571968079,81.43400012207032,95.90000008544922 215,2.9962432842988234,0.961731686630249,81.57800007568359,95.9020000830078 216,2.964731344809899,0.9199039099502564,81.68000009765625,95.98400018798829 217,2.976665368446937,0.9713804598045349,81.63999996826172,95.97400010986328 218,2.938941423709576,0.9095170165824891,81.77400017822265,95.99000010986327 219,2.9089973247968235,0.9290500170707703,81.55399999511718,96.03600003173828 220,2.972458839416504,0.9646587027359009,81.6720000756836,95.95600000488281 221,2.9590316002185526,0.9856760236167907,81.76999997314454,96.00800013671875 222,2.924618986936716,0.8967197008514405,81.8499999975586,95.98600000732422 223,2.9153974881538978,0.9025303680419922,81.80200012695313,96.00800008544923 224,2.9513279658097487,0.9642524267578125,81.81600002441407,95.99200000732422 225,2.925851684350234,0.9187241349601746,81.70800007324219,96.08600005859375 226,2.944014842693622,0.9485834597969055,81.83400004882813,96.16800005859375 227,2.9328464544736423,0.9324082082939148,81.9540000756836,96.17800000732421 228,2.925115887935345,0.8715889813232421,82.11000005371093,96.16800003173829 229,2.882812270751366,0.9478203728675842,81.99599999511719,96.19000018798828 230,2.893248823972849,0.8939307011795044,82.06799997070313,96.13000008544923 231,2.909112334251404,0.9412263136482238,82.08400004882813,96.15599998046875 232,2.8751327991485596,0.9333421780395508,82.09000010253907,96.07600006103516 
233,2.8806763612307034,0.8944049692344666,82.09599999511718,96.09800000732422 234,2.8830181176845846,0.8789783415031434,82.22800002197266,96.13400000732422 235,2.883752153469966,0.9032074878883362,82.192,96.12400000732421 236,2.8425116630700917,0.9093195487785339,82.38400007568359,96.11000010986328 237,2.8728231375034037,0.9299646599769592,82.35999999511719,96.20400003173827 238,2.854443614299481,0.8940147971343995,82.43400007324219,96.19600005859375 239,2.8340333149983334,0.8646629189872742,82.15599997070312,96.22799992919921 240,2.834600338569054,0.9099975925445557,82.34599997070312,96.26400000732421 241,2.8750672890589786,0.9240939585876465,82.35999994628907,96.22399990478516 242,2.881342814518855,0.9277827157592774,82.29400004882812,96.18600000732422 243,2.800852289566627,0.9207671698379517,82.42399997070312,96.29199995605468 244,2.826339785869305,0.8509364444160461,82.48799994140624,96.32199992919922 245,2.832765579223633,0.8535477470588684,82.51800002197265,96.30199992919921 246,2.835671140597417,0.8540577458381653,82.61000007324219,96.34400010986329 247,2.835276411129878,0.88917942527771,82.58200007324218,96.36600016357421 248,2.8517292187764096,0.8996915315628051,82.60200004638672,96.34800005859375 249,2.83868577847114,0.8852077262115479,82.71800018066406,96.34200000732422 250,2.795175653237563,0.868073853187561,82.7640000756836,96.36400003417968 251,2.8127903846594005,0.9428228170013427,82.69600015136719,96.34999992919921 252,2.7924872911893406,0.8700160450935364,82.73600002441407,96.37400000732421 253,2.8241086097864003,0.9485920486831665,82.66599991699219,96.35400008544921 254,2.820255627998939,0.8666281459808349,82.63400012207032,96.34800013671875 255,2.767980071214529,0.9426617022705078,82.69000001953125,96.39400010986328 256,2.7815559827364407,0.8850377077674866,82.79000001953125,96.40200008544922 257,2.7684466472038856,0.9531537490272523,82.81200012695312,96.39999998291016 258,2.776091988270099,0.9598657544517517,82.78800007324219,96.41400000732422 
259,2.7829883373700657,0.8710285722160339,82.94200017578125,96.40200010986328 260,2.7462159945414615,0.8686052158546448,82.83200002197266,96.44399998046875 261,2.745678883332473,0.8464594898796082,83.03600012451172,96.44399998046875 262,2.746650503231929,0.9291099748039245,82.88800014892578,96.4680001123047 263,2.7892674849583554,0.9198091730690002,82.92000009765626,96.48000010986328 264,2.7777191950724673,0.876633308506012,83.00000009765625,96.48200000732422 265,2.790822093303387,0.8607185919570923,82.98600007324218,96.47999995605468 266,2.7587546568650465,0.8637187901306153,83.12800012207032,96.47200010986329 267,2.7561703461867113,0.9178548712730408,83.01000004394531,96.47800000732421 268,2.753976033284114,0.9308101778793335,83.02200007324218,96.47800005859375 269,2.7478087773689857,0.905781350402832,83.07399999511719,96.48400005859375 270,2.7359400345728946,0.8987113190078735,83.10199999511718,96.49800000732422 271,2.7687891721725464,0.8974916218185425,83.10600001953125,96.47800005859375 272,2.7437755236258874,0.8698487334251404,83.1459999951172,96.53000010986328 273,2.740336509851309,0.8323400824928283,83.19800015136718,96.50200005859375 274,2.7227883522327128,0.9540283364486695,83.18800015136719,96.49400000732422 275,2.7484019077741184,0.8897652346038818,83.22799999511719,96.52200005859375 276,2.7232350294406595,0.8671867720222474,83.13600015136718,96.55000005859375 277,2.7706412352048435,0.914162262840271,83.18800010009765,96.52600005859375 278,2.7158928284278283,0.8688316952896118,83.13200010009766,96.55800000732422 279,2.757467407446641,0.8733858386993408,83.24800004638672,96.50800005859375 280,2.680275009228633,0.8916000234222412,83.22200002197266,96.53800005859375 281,2.7301262342012844,0.9303448626327515,83.25600007324219,96.54400005859375 282,2.7311555605668287,0.9017908967781066,83.27000007324219,96.53000005859376 283,2.778689439480121,0.8907459851264954,83.33400012451172,96.55800005859375 
284,2.730280564381526,0.8617753137397766,83.29599994384766,96.55600000732422 285,2.6996115996287418,0.9456994646072387,83.32199994384766,96.54800000732422 286,2.7530321524693417,0.8394892182922363,83.29600002197266,96.53600000732422 287,2.698102043225215,0.871575404548645,83.29799999511718,96.52800000732422 288,2.7162117774669943,0.9279066291618348,83.29799994384766,96.51800000732422 289,2.731216476513789,0.870179077796936,83.27600007324219,96.53200005859375 290,2.7172455145762515,0.8693408039665222,83.32000009765625,96.56200000732422 291,2.7217825100972104,0.862339599533081,83.32000002197266,96.51200000732422 292,2.738798902584956,0.8762007956314087,83.29199994384766,96.54600000732422 293,2.6903779139885535,0.9252629530525207,83.30799994384766,96.55800000732422 294,2.7017551843936625,0.8361762863922119,83.3040000732422,96.57200000732422 295,2.7112444455807028,0.8382016857719421,83.30800012451172,96.52600000732421 296,2.7227955231299767,0.8906204897880554,83.32800012451172,96.53000000732422 297,2.702854578311627,0.8682446762657166,83.31999994384766,96.54200000732422 298,2.7227264734414907,0.9243216412544251,83.33200002197266,96.49400000732422 299,2.6936778105222263,0.813800574092865,83.33000002197265,96.56000000732422 300,2.7342205964601956,0.9972933919715882,83.30199999511719,96.50200000732421 301,2.696374086233286,0.8523485798072815,83.36800012451172,96.52000000732421 302,2.732568465746366,0.9299849660491943,83.33600004638672,96.53600000732422 303,2.7426971655625563,0.8711758170700074,83.34400007324219,96.53000000732422 304,2.7026734902308536,0.9297433455467224,83.31600012451172,96.51400000732421 305,2.7032283636239858,0.9165915111351013,83.37599994384766,96.53600000732422 306,2.7004379217441263,0.8460414072036743,83.38199999511718,96.55800000732422 307,2.7239219958965597,0.8808569229316712,83.37600012451172,96.53800000732421 308,2.6918300573642435,0.8398378726959228,83.37200002197265,96.55200000732422 
309,2.699970080302312,0.8693196088981628,83.35399994384765,96.55000000732421 ================================================ FILE: checkpoint_384/iformer_base_384/args.yaml ================================================ aa: rand-m9-mstd0.5-inc1 amp: false apex_amp: true aug_repeats: 3 aug_splits: 0 batch_size: 32 bce_loss: false bce_target_thresh: null bn_eps: null bn_momentum: null bn_tf: false channels_last: false checkpoint_hist: 1 class_map: '' clip_grad: 1.0 clip_mode: norm color_jitter: 0.4 cooldown_epochs: 10 crop_pct: null cutmix: 0.1 cutmix_minmax: null data_dir: /dataset/imagenet-raw dataset: '' dataset_download: false decay_epochs: 30.0 decay_rate: 0.1 dist_bn: reduce drop: 0.0 drop_block: null drop_connect: null drop_path: 0.5 embed_dim: 384 epoch_repeats: 0.0 epochs: 20 eval_metric: top1 experiment: iformer_base_384 gp: null hflip: 0.5 img_size: 384 initial_checkpoint: checkpoint/iformer_base/model_best.pth.tar input_size: null interpolation: '' jsd_loss: false local_rank: 0 log_interval: 50 log_wandb: false lr: 5.0e-06 lr_cycle_decay: 0.5 lr_cycle_limit: 1 lr_cycle_mul: 1.0 lr_k_decay: 1.0 lr_noise: null lr_noise_pct: 0.67 lr_noise_std: 1.0 mean: null min_lr: 5.0e-07 mixup: 0.1 mixup_mode: batch mixup_off_epoch: 0 mixup_prob: 1.0 mixup_switch_prob: 0.5 model: iformer_base_384 model_ema: false model_ema_decay: 0.9998 model_ema_force_cpu: false momentum: 0.9 native_amp: false no_aug: false no_ddp_bb: false no_prefetcher: true no_resume_opt: false num_classes: null opt: adamw opt_betas: null opt_eps: 1.0e-08 output: checkpoint_384 patience_epochs: 10 pin_mem: false port: '25500' pretrained: false ratio: - 0.75 - 1.3333333333333333 recount: 1 recovery_interval: 0 remode: pixel reprob: 0.25 resplit: false resume: '' save_images: false scale: - 0.08 - 1.0 sched: cosine seed: 42 smoothing: 0.1 split_bn: false start_epoch: null std: null sync_bn: false torchscript: false train_interpolation: random train_split: train tta: 0 use_multi_epochs_loader: false 
val_split: validation validation_batch_size: null vflip: 0.0 warmup_epochs: 0 warmup_lr: 2.0e-08 weight_decay: 1.0e-08 worker_seeding: all workers: 10 ================================================ FILE: checkpoint_384/iformer_base_384/summary.csv ================================================ epoch,train_loss,eval_loss,eval_top1,eval_top5 0,2.162344753742218,0.6423477774047851,85.416,97.554 1,2.14225176152061,0.6349599724960328,85.544,97.552 2,2.1576704359522054,0.63505445728302,85.514,97.558 3,2.151329924078549,0.6296176392364502,85.482,97.552 4,2.16335995758281,0.6267058710479736,85.574,97.572 5,2.123039128733616,0.6339057161712647,85.602,97.582 6,2.1556805185243193,0.6313015496826172,85.622,97.6 7,2.1183225535878947,0.6357844618606567,85.594,97.558 8,2.143087193077686,0.6312218860626221,85.548,97.582 9,2.0986259544596955,0.625974910697937,85.654,97.574 10,2.111677838306801,0.6280780986785889,85.61,97.62 11,2.139703644256966,0.6327868656158447,85.698,97.598 12,2.1040731925590364,0.6322624744033813,85.684,97.592 13,2.1542149863991082,0.6258501932525635,85.684,97.594 14,2.1099117760564767,0.6318398251342774,85.672,97.57 15,2.1123981055091408,0.6286558788299561,85.654,97.628 16,2.1304921811702204,0.6399081578063965,85.676,97.594 17,2.138498722338209,0.6392654892730713,85.702,97.604 18,2.1148997580303863,0.6415886449813842,85.67,97.604 19,2.1841310519798127,0.6299422284317017,85.648,97.612 20,2.114746145173615,0.6347110688018799,85.674,97.604 21,2.141511425083759,0.6325080072784424,85.68,97.604 22,2.1408919329736746,0.6330610972595215,85.696,97.61 23,2.14104422751595,0.6302861064147949,85.702,97.604 24,2.1165792801800896,0.6278342678070068,85.69,97.596 25,2.1689459833444333,0.6362991049194336,85.706,97.618 26,2.106649726044898,0.6317873718261718,85.69,97.612 27,2.130349175602782,0.6273881248474121,85.698,97.602 28,2.101420769504472,0.629305263748169,85.692,97.6 29,2.1075617484017912,0.6330888885116577,85.67,97.614 ================================================ 
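The `summary.csv` files above follow timm's per-epoch log layout (`epoch,train_loss,eval_loss,eval_top1,eval_top5`). A minimal sketch of pulling the best-accuracy epoch out of such a log — the helper name `best_epoch` and the inline three-row sample (copied from the first epochs of `checkpoint_384/iformer_base_384/summary.csv`) are illustrative, not part of the repo:

```python
import csv
import io

def best_epoch(csv_text, metric="eval_top1"):
    """Return (epoch, value) for the row maximizing `metric`
    in a timm-style summary.csv."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    best = max(rows, key=lambda r: float(r[metric]))
    return int(best["epoch"]), float(best[metric])

# First three epochs of checkpoint_384/iformer_base_384/summary.csv
sample = """epoch,train_loss,eval_loss,eval_top1,eval_top5
0,2.162344753742218,0.6423477774047851,85.416,97.554
1,2.14225176152061,0.6349599724960328,85.544,97.552
2,2.1576704359522054,0.63505445728302,85.514,97.558
"""

print(best_epoch(sample))  # (1, 85.544)
```

In practice one would pass `open("checkpoint/iformer_small/summary.csv").read()` instead of the inline sample; the same function works for any of the logs in `checkpoint/` or `checkpoint_384/`.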
FILE: checkpoint_384/iformer_large_384/args.yaml ================================================ aa: rand-m9-mstd0.5-inc1 amp: false apex_amp: true aug_repeats: 3 aug_splits: 0 batch_size: 32 bce_loss: false bce_target_thresh: null bn_eps: null bn_momentum: null bn_tf: false channels_last: false checkpoint_hist: 1 class_map: '' clip_grad: 1.0 clip_mode: norm color_jitter: 0.4 cooldown_epochs: 10 crop_pct: null cutmix: 0.1 cutmix_minmax: null data_dir: /dataset/imagenet-raw dataset: '' dataset_download: false decay_epochs: 30.0 decay_rate: 0.1 dist_bn: reduce drop: 0.0 drop_block: null drop_connect: null drop_path: 0.6 embed_dim: 384 epoch_repeats: 0.0 epochs: 20 eval_metric: top1 experiment: iformer_large_384 gp: null hflip: 0.5 img_size: 384 initial_checkpoint: checkpoint/iformer_large/model_best.pth.tar input_size: null interpolation: '' jsd_loss: false local_rank: 0 log_interval: 50 log_wandb: false lr: 1.0e-05 lr_cycle_decay: 0.5 lr_cycle_limit: 1 lr_cycle_mul: 1.0 lr_k_decay: 1.0 lr_noise: null lr_noise_pct: 0.67 lr_noise_std: 1.0 mean: null min_lr: 1.0e-06 mixup: 0.1 mixup_mode: batch mixup_off_epoch: 0 mixup_prob: 1.0 mixup_switch_prob: 0.5 model: iformer_large_384 model_ema: false model_ema_decay: 0.9998 model_ema_force_cpu: false momentum: 0.9 native_amp: false no_aug: false no_ddp_bb: false no_prefetcher: true no_resume_opt: false num_classes: null opt: adamw opt_betas: null opt_eps: 1.0e-08 output: checkpoint_384 patience_epochs: 10 pin_mem: false port: '25500' pretrained: false ratio: - 0.75 - 1.3333333333333333 recount: 1 recovery_interval: 0 remode: pixel reprob: 0.25 resplit: false resume: '' save_images: false scale: - 0.08 - 1.0 sched: cosine seed: 42 smoothing: 0.1 split_bn: false start_epoch: null std: null sync_bn: false torchscript: false train_interpolation: random train_split: train tta: 0 use_multi_epochs_loader: false val_split: validation validation_batch_size: null vflip: 0.0 warmup_epochs: 0 warmup_lr: 2.0e-08 weight_decay: 1.0e-08 
worker_seeding: all workers: 10 ================================================ FILE: checkpoint_384/iformer_large_384/summary.csv ================================================ epoch,train_loss,eval_loss,eval_top1,eval_top5 0,2.1128095227938433,0.6339371264839172,85.65200001953124,97.5399999609375 1,2.112883148285059,0.6371526368713379,85.67800004394532,97.5179999609375 2,2.081270286670098,0.6324703786468506,85.65800001953124,97.52000003417969 3,2.0541468331447015,0.6338681053352356,85.69799999511719,97.5360000341797 4,2.1174410466964426,0.6423702626800537,85.70999999267578,97.52200000976562 5,2.06852941100414,0.6357672811508178,85.6640000390625,97.57600003417969 6,2.0652733536866994,0.6328511357688904,85.69000004394532,97.52999998535157 7,2.090355485677719,0.6328032349014282,85.73000006835937,97.57000001220703 8,2.0464458832373986,0.6333639372062683,85.70000001708985,97.5600000341797 9,2.1092736858588,0.6316976878166198,85.66799996582031,97.57000000976562 10,2.0740141524718356,0.6331033121871948,85.74800004394531,97.51800000976563 11,2.0925170343655806,0.6338267395019531,85.76999999511719,97.54000000976562 12,2.0762338707080255,0.6300550039291382,85.79600004394531,97.58400000976563 13,2.0655095875263214,0.6340151790809632,85.72400001953125,97.56800000976563 14,2.0131444334983826,0.6341021949386597,85.77399999023437,97.56400000976562 15,2.045399464093722,0.6350232722473145,85.78200004394532,97.56200000976563 16,2.0737479076935696,0.6348873478507996,85.78999999023438,97.56200000976563 17,2.0775972329653225,0.6329525914382934,85.7680000439453,97.55200000976562 18,2.0284574719575734,0.6307329417991638,85.77000004394532,97.56600000976563 19,2.0756700291083408,0.6293472344970703,85.79199996582031,97.56600000976563 20,2.0585412039206576,0.6404127075767517,85.81800001464843,97.56200000976563 21,2.107305421279027,0.6296587294960022,85.83400001708985,97.56400000976562 22,2.0651893615722656,0.6286195637512207,85.81999999023438,97.55600000976563 
23,2.0812353377158823,0.6351058733177185,85.80600004394532,97.56200000976563 24,2.030002793440452,0.630206987991333,85.84400001708984,97.55400000976563 25,2.037764448385972,0.6358882413291931,85.80599999267578,97.56200000976563 26,2.0813280573258033,0.6310837986183166,85.77999999267578,97.55800000976562 27,2.06927875830577,0.6295442185211182,85.83399999267579,97.56600000976563 28,2.0359999583317685,0.6304517384147644,85.80800001708984,97.55800000976562 29,2.0422337972200832,0.6342641407775879,85.80799999267578,97.53600000976563 ================================================ FILE: checkpoint_384/iformer_small_384/args.yaml ================================================ aa: rand-m9-mstd0.5-inc1 amp: false apex_amp: false aug_repeats: 0 aug_splits: 0 batch_size: 32 bce_loss: false bce_target_thresh: null bn_eps: null bn_momentum: null bn_tf: false channels_last: false checkpoint_hist: 1 class_map: '' clip_grad: 1.0 clip_mode: norm color_jitter: 0.4 cooldown_epochs: 10 crop_pct: null cutmix: 0.1 cutmix_minmax: null data_dir: /dataset/imagenet-raw dataset: '' dataset_download: false decay_epochs: 100 decay_rate: 0.1 dist_bn: reduce drop: 0.0 drop_block: null drop_connect: null drop_path: 0.3 embed_dim: 384 epoch_repeats: 0.0 epochs: 20 eval_metric: top1 experiment: iformer_small_384 gp: null hflip: 0.5 img_size: 384 initial_checkpoint: checkpoint/iformer_small/iformer_small_checkpoint.pth input_size: null interpolation: '' jsd_loss: false local_rank: 0 log_interval: 50 log_wandb: false lr: 1.0e-05 lr_cycle_decay: 0.5 lr_cycle_limit: 1 lr_cycle_mul: 1.0 lr_k_decay: 1.0 lr_noise: null lr_noise_pct: 0.67 lr_noise_std: 1.0 mean: null min_lr: 1.0e-06 mixup: 0.1 mixup_mode: batch mixup_off_epoch: 0 mixup_prob: 1.0 mixup_switch_prob: 0.5 model: iformer_small_384 model_ema: false model_ema_decay: 0.9998 model_ema_force_cpu: false momentum: 0.9 native_amp: false no_aug: false no_ddp_bb: false no_prefetcher: true no_resume_opt: false num_classes: null opt: adamw opt_betas: 
null opt_eps: null output: checkpoint_384 patience_epochs: 10 pin_mem: false port: '25500' pretrained: false ratio: - 0.75 - 1.3333333333333333 recount: 1 recovery_interval: 0 remode: pixel reprob: 0.25 resplit: false resume: '' save_images: false scale: - 0.08 - 1.0 sched: cosine seed: 42 smoothing: 0.1 split_bn: false start_epoch: null std: null sync_bn: false torchscript: false train_interpolation: random train_split: train tta: 0 use_multi_epochs_loader: false val_split: validation validation_batch_size: null vflip: 0.0 warmup_epochs: 0 warmup_lr: 2.0e-08 weight_decay: 1.0e-08 worker_seeding: all workers: 10 ================================================ FILE: checkpoint_384/iformer_small_384/summary.csv ================================================ epoch,train_loss,eval_loss,eval_top1,eval_top5 0,2.3478839855927687,0.6923815357589722,84.32999999023437,97.18999998535156 1,2.3396820655235877,0.6849766717720032,84.33399996582031,97.14000003417969 2,2.2849164650990414,0.7006625423812867,84.32999998779297,97.19599995605469 3,2.2638514316998997,0.6839149640274048,84.42799999023437,97.21000003417969 4,2.3003986065204325,0.6951746271324157,84.51999996582032,97.20600003417968 5,2.265346252001249,0.6716998846244812,84.4420000390625,97.20200003417969 6,2.2564402039234457,0.677104370727539,84.4259999633789,97.22399998535157 7,2.3008200640861807,0.6659402999305725,84.46000001464844,97.27000000976562 8,2.253789108533126,0.6801115174865723,84.54399996582032,97.21400003417969 9,2.3075229708965006,0.6605272066688538,84.46800001464844,97.1820000341797 10,2.284832917726957,0.6837751863861083,84.48199999023437,97.19200000976562 11,2.3022379279136658,0.6809691307640076,84.53400001464844,97.21200000976563 12,2.2808263347699094,0.6799604147911071,84.47399999267579,97.25200003417969 13,2.2778979677420397,0.6703613100624084,84.52200001953125,97.21200000976563 14,2.220526851140536,0.6813768936157226,84.49400001953126,97.22800003417969 
15,2.236867505770463,0.6821543011665344,84.56800001953125,97.22800003417969 16,2.2702108713296743,0.6934137293052673,84.51199999023437,97.22400003417968 17,2.2874821562033434,0.7100320655632019,84.55599999023437,97.21800003417968 18,2.2137838762540083,0.6670947971534729,84.56599996582031,97.20400000976562 19,2.2960855869146495,0.673701826992035,84.54999999023437,97.22600003417969 20,2.260560466692998,0.6704543328857422,84.62000001464844,97.22000000976563 21,2.3083780178656945,0.6758274334907531,84.5859999658203,97.21600000976562 22,2.276798422519977,0.6762671936225891,84.54399996582032,97.19600000976563 23,2.2856121980226956,0.6660497138595581,84.61199996582032,97.17600000976563 24,2.2410703026331387,0.68424802734375,84.5979999658203,97.24800000976562 25,2.248056019728,0.65673936378479,84.59199999023437,97.23400000976562 26,2.2597031868421116,0.6829823554992676,84.56999999023438,97.21800003417968 27,2.2474928085620585,0.6944181496810913,84.5619999658203,97.23200003417969 28,2.255860697764617,0.6809109293937683,84.54399996582032,97.24600000976562 29,2.2558590265420766,0.6986223343276977,84.56799996337891,97.25800003417969 ================================================ FILE: fine-tune.py ================================================ #!/usr/bin/env python3 """ ImageNet Training Script This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet training results with some of the latest networks and training techniques. It favours canonical PyTorch and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. 
This script was started from an early version of the PyTorch ImageNet example (https://github.com/pytorch/examples/tree/master/imagenet) NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ import argparse import time # from types import _KT import yaml import os import logging from collections import OrderedDict from contextlib import suppress from datetime import datetime import torch import torch.nn as nn import torch.nn.functional as F import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import * from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler from fvcore.nn import FlopCountAnalysis from fvcore.nn import flop_count_table try: from apex import amp from apex.parallel import DistributedDataParallel as ApexDDP from apex.parallel import convert_syncbn_model has_apex = True except ImportError: has_apex = False has_native_amp = False try: if getattr(torch.cuda.amp, 'autocast') is not None: has_native_amp = True except AttributeError: pass try: import wandb has_wandb = True except ImportError: has_wandb = False torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') # The first arg parser parses out only the --config argument, this argument is used to # load a yaml file containing key-values that override the defaults for the main parser below config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) parser.add_argument('-c', 
'--config', default='', type=str, metavar='FILE', help='YAML config file specifying default arguments') parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='path to dataset') parser.add_argument('--dataset', '-d', metavar='NAME', default='', help='dataset type (default: ImageFolder/ImageTar if empty)') parser.add_argument('--train-split', metavar='NAME', default='train', help='dataset train split (default: train)') parser.add_argument('--val-split', metavar='NAME', default='validation', help='dataset validation split (default: validation)') parser.add_argument('--dataset-download', action='store_true', default=False, help='Allow download of dataset for torch/ and tfds/ datasets that support it.') parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') # Model parameters parser.add_argument('--model', default='deit_small_patch16_224', type=str, metavar='MODEL', help='Name of model to train (default: "deit_small_patch16_224")') parser.add_argument('--pretrained', action='store_true', default=False, help='Start with pretrained version of specified network (if avail)') parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', help='Initialize model from this checkpoint (default: none)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume full model and optimizer state from checkpoint (default: none)') parser.add_argument('--no-resume-opt', action='store_true', default=False, help='prevent resume of optimizer state when resuming model') parser.add_argument('--num-classes', type=int, default=None, metavar='N', help='number of label classes (Model default if None)') parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc).
Model default if None.') parser.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop percent (for validation only)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', help='Override std deviation of dataset') parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', help='input batch size for training (default: 128)') parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', help='validation batch size override (default: None)') # Optimizer parameters parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")') parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: None, use opt default)') parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='Optimizer momentum (default: 0.9)') parser.add_argument('--weight-decay', type=float, default=2e-5, help='weight decay (default: 2e-5)') parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') parser.add_argument('--clip-mode', type=str, default='norm', help='Gradient clipping mode.
One of ("norm", "value", "agc")') # Learning rate schedule parameters parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine")') parser.add_argument('--lr', type=float, default=0.05, metavar='LR', help='learning rate (default: 0.05)') parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', help='learning rate cycle len multiplier (default: 1.0)') parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', help='amount to decay each learning rate cycle (default: 0.5)') parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', help='learning rate cycle limit, cycles enabled if > 1') parser.add_argument('--lr-k-decay', type=float, default=1.0, help='learning rate k-decay for cosine/poly (default: 1.0)') parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', help='warmup learning rate (default: 0.0001)') parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)') parser.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)') parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', help='epoch
interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10)') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') # Augmentation & regularization parameters parser.add_argument('--no-aug', action='store_true', default=False, help='Disable all training augmentation, override other train aug args') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--hflip', type=float, default=0.5, help='Horizontal flip training aug probability') parser.add_argument('--vflip', type=float, default=0., help='Vertical flip training aug probability') parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". (default: None)') parser.add_argument('--aug-repeats', type=int, default=0, help='Number of augmentation repetitions (distributed training only) (default: 0)') parser.add_argument('--aug-splits', type=int, default=0, help='Number of augmentation splits (default: 0, valid: 0 or >=2)') parser.add_argument('--jsd-loss', action='store_true', default=False, help='Enable Jensen-Shannon Divergence + CE loss.
Use with `--aug-splits`.') parser.add_argument('--bce-loss', action='store_true', default=False, help='Enable BCE loss w/ Mixup/CutMix use.') parser.add_argument('--bce-target-thresh', type=float, default=None, help='Threshold for binarizing softened BCE targets (default: None, disabled)') parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob (default: 0.25)') parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")') parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') parser.add_argument('--cutmix', type=float, default=0.0, help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') parser.add_argument('--mixup-prob', type=float, default=1.0, help='Probability of performing mixup or cutmix when either/both is enabled') parser.add_argument('--mixup-switch-prob', type=float, default=0.5, help='Probability of switching to cutmix when both mixup and cutmix enabled') parser.add_argument('--mixup-mode', type=str, default='batch', help='How to apply mixup/cutmix params.
Per "batch", "pair", or "elem"') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='Turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') parser.add_argument('--train-interpolation', type=str, default='random', help='Training interpolation (random, bilinear, bicubic default: "random")') parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', help='Drop connect rate, DEPRECATED, use drop-path (default: None)') parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', help='Drop path rate (default: None)') parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', help='Drop block rate (default: None)') # Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-tf', action='store_true', default=False, help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') parser.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)') parser.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') parser.add_argument('--sync-bn', action='store_true', help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') parser.add_argument('--dist-bn', type=str, default='reduce', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') parser.add_argument('--split-bn', action='store_true', help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') parser.add_argument('--model-ema-decay', type=float, default=0.9998, help='decay factor for model weights moving average (default: 0.9998)') # Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') parser.add_argument('--worker-seeding', type=str, default='all', help='worker seed mode (default: all)') parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') parser.add_argument('--checkpoint-hist', type=int, default=1, metavar='N', help='number of checkpoints to keep (default: 1)') parser.add_argument('-j', '--workers', type=int, default=10, metavar='N', help='how many training processes to use (default: 10)') parser.add_argument('--save-images', action='store_true', default=False, help='save images of input batches every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA Apex AMP or Native AMP for mixed precision training') parser.add_argument('--apex-amp', action='store_true', default=False, help='Use NVIDIA Apex AMP mixed precision') parser.add_argument('--native-amp', action='store_true', default=False, help='Use Native Torch AMP mixed precision') parser.add_argument('--no-ddp-bb', action='store_true', default=False, help='Force broadcast buffers for native DDP to off.') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no-prefetcher', action='store_true', default=True, help='disable fast prefetcher')
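The CLI flags above interact with the YAML configs under `checkpoint_384/` through a two-stage parse: a minimal config parser extracts only `--config`, and the YAML contents are pushed into the main parser via `set_defaults()` before the remaining arguments are parsed, so explicit CLI flags still win. A minimal, self-contained sketch of that precedence logic (illustrative only, not part of the original `fine-tune.py`; `parse_with_overrides` is a hypothetical helper, and in the real script the override dict comes from `yaml.safe_load` on the `--config` file):

```python
import argparse


def parse_with_overrides(argv, config_overrides=None):
    # Stage 2 of the pattern: the main parser, whose defaults may already
    # have been replaced by values loaded from a YAML config in stage 1.
    parser = argparse.ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--epochs', type=int, default=300)
    if config_overrides:
        parser.set_defaults(**config_overrides)  # config file replaces defaults
    return parser.parse_args(argv)               # explicit CLI flags still win


# Resulting precedence: CLI flag > config-file value > parser default.
args = parse_with_overrides(['--lr', '0.01'], {'lr': 1e-05, 'epochs': 20})
# args.lr comes from the CLI (0.01); args.epochs from the config dict (20).
```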
parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') parser.add_argument('--experiment', default='', type=str, metavar='NAME', help='name of train experiment, name of sub-folder for output') parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', help='Best metric (default: "top1")') parser.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') parser.add_argument("--local_rank", default=0, type=int) parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model to torchscript for inference') parser.add_argument('--log-wandb', action='store_true', default=False, help='log training and validation metrics to wandb') def _parse_args(): # Do we have a config file to parse? args_config, remaining = config_parser.parse_known_args() if args_config.config: with open(args_config.config, 'r') as f: cfg = yaml.safe_load(f) parser.set_defaults(**cfg) # The main arg parser parses the rest of the args, the usual # defaults will have been overridden if config file specified.
args = parser.parse_args(remaining) # Cache the args as a text string to save them in the output dir later args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) return args, args_text def load_init_checkpoint(model, checkpoint_path): if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: _logger.info('Restoring model state from checkpoint...') new_state_dict = OrderedDict() # model_state_dict = model.state_dict() for k, v in checkpoint['state_dict'].items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v # model_state_dict[name] = v model.load_state_dict(new_state_dict) _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) else: model.load_state_dict(checkpoint) _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def main(): setup_default_logging() args, args_text = _parse_args() if args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) else: _logger.warning("You've requested to log metrics to wandb but package not found. " "Metrics not being logged to wandb, try `pip install wandb`") args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.device = 'cuda:0' args.world_size = 1 args.rank = 0 # global rank if args.distributed: args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() args.rank = torch.distributed.get_rank() _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 
% (args.rank, args.world_size)) else: _logger.info('Training with a single process on 1 GPU.') assert args.rank >= 0 # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: # `--amp` chooses native amp before apex (APEX ver not actively maintained) if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True if args.apex_amp and has_apex: use_amp = 'apex' elif args.native_amp and has_native_amp: use_amp = 'native' elif args.apex_amp or args.native_amp: _logger.warning("Neither APEX nor native Torch AMP is available, using float32. " "Install NVIDIA apex or upgrade to PyTorch 1.6") random_seed(args.seed, args.rank) model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly model.eval() flops = FlopCountAnalysis(model, torch.rand(1, 3, 384, 384)) if args.local_rank == 0: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') print(flop_count_table(flops)) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 if args.aug_splits > 0: assert args.aug_splits > 1, 'A split of 1 makes no sense' num_aug_splits = args.aug_splits # enable split bn (separate bn stats per batch-portion) if args.split_bn: assert num_aug_splits > 1 or args.resplit model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set model.cuda() if args.channels_last: model = model.to(memory_format=torch.channels_last) # setup synchronized BatchNorm for distributed training # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp == 'apex': # Apex SyncBN preferred unless native amp is activated model = convert_syncbn_model(model) else: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. 
                'WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        if os.path.exists(args.resume):
            resume_epoch = resume_checkpoint(
                model, args.resume,
                optimizer=None if args.no_resume_opt else optimizer,
                loss_scaler=None if args.no_resume_opt else loss_scaler,
                log_info=args.local_rank == 0,
            )

    if args.initial_checkpoint:
        if os.path.exists(args.initial_checkpoint):
            load_init_checkpoint(
                model,
                args.initial_checkpoint,
            )  # memory=student_mem

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        repeats=args.epoch_repeats)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)
    if collate_fn is not None:
        print('collate_fn is not none')
    if mixup_fn is not None:
        print('mixup_fn is not none')

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        entropy_thr = 0
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics, all_entropy = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver,
                output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler,
                model_ema=model_ema, mixup_fn=mixup_fn, entropy_thr=entropy_thr)
            # all_entropy = torch.stack(all_entropy, dim=0)
            # entropy_thr = all_entropy.mean()
            # entropy_thr = entropy_thr * 2.0

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast,
                    log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None, entropy_thr=None):

    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()
    all_entropy = []

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
            if mixup_fn is not None:
                input_mix, target_mix = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(input_mix)
            # output = model(input)
            loss = loss_fn(output, target_mix)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer,
                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(
                    model_parameters(model, exclude_head='agc' in args.clip_mode),
                    value=args.clip_grad, mode=args.clip_mode)
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)

        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)]), all_entropy


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
                # print(output.size())
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m, loss=losses_m, top1=top1_m, top5=top5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics


if __name__ == '__main__':
    main()


================================================
FILE: models/__init__.py
================================================
from .inception_transformer import *


================================================
FILE: models/inception_transformer.py
================================================
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inception transformer implementation.
Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models).
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from timm.models.registry import register_model
from torch.nn.init import _calculate_fan_in_and_fan_out
import math
import warnings
from timm.models.layers.helpers import to_2tuple

_logger = logging.getLogger(__name__)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


# default_cfgs = {
#     'iformer_224': _cfg(),
#     'iformer_384': _cfg(input_size=(3, 384, 384), crop_pct=1.0),
# }

default_cfgs = {
    'iformer_small': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small.pth'),
    'iformer_base': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base.pth'),
    'iformer_large': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large.pth'),
    'iformer_small_384': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small_384.pth',
                              input_size=(3, 384, 384), crop_pct=1.0),
    'iformer_base_384': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base_384.pth',
                             input_size=(3, 384, 384), crop_pct=1.0),
    'iformer_large_384':
        _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large_384.pth',
             input_size=(3, 384, 384), crop_pct=1.0),
}


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, kernel_size=16, stride=16, padding=0, in_chans=3, embed_dim=768):
        super().__init__()
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        # B, C, H, W = x.shape
        x = self.proj(x)
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)
        return x


class FirstPatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, kernel_size=3, stride=2, padding=1, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj1 = nn.Conv2d(in_chans, embed_dim // 2, kernel_size=kernel_size, stride=stride, padding=padding)
        self.norm1 = nn.BatchNorm2d(embed_dim // 2)
        self.gelu1 = nn.GELU()
        self.proj2 = nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
        self.norm2 = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        # B, C, H, W = x.shape
        x = self.proj1(x)
        x = self.norm1(x)
        x = self.gelu1(x)
        x = self.proj2(x)
        x = self.norm2(x)
        x = x.permute(0, 2, 3, 1)
        return x


class HighMixer(nn.Module):
    def __init__(self, dim, kernel_size=3, stride=1, padding=1, **kwargs):
        super().__init__()

        self.cnn_in = cnn_in = dim // 2
        self.pool_in = pool_in = dim // 2

        self.cnn_dim = cnn_dim = cnn_in * 2
        self.pool_dim = pool_dim = pool_in * 2

        self.conv1 = nn.Conv2d(cnn_in, cnn_dim, kernel_size=1, stride=1, padding=0, bias=False)
        self.proj1 = nn.Conv2d(cnn_dim, cnn_dim, kernel_size=kernel_size, stride=stride, padding=padding,
                               bias=False, groups=cnn_dim)
        self.mid_gelu1 = nn.GELU()

        self.Maxpool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)
        self.proj2 = nn.Conv2d(pool_in, pool_dim, kernel_size=1, stride=1, padding=0)
        self.mid_gelu2 = nn.GELU()

    def forward(self, x):
        # B, C, H, W
        cx = x[:, :self.cnn_in, :, :].contiguous()
        cx = self.conv1(cx)
        cx = self.proj1(cx)
        cx = self.mid_gelu1(cx)

        px = x[:, self.cnn_in:, :, :].contiguous()
        px = self.Maxpool(px)
        px = self.proj2(px)
        px = self.mid_gelu2(px)

        hx = torch.cat((cx, px), dim=1)
        return hx


class LowMixer(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pool_size=2, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.dim = dim

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)

        self.pool = nn.AvgPool2d(pool_size, stride=pool_size, padding=0,
                                 count_include_pad=False) if pool_size > 1 else nn.Identity()
        self.uppool = nn.Upsample(scale_factor=pool_size) if pool_size > 1 else nn.Identity()

    def att_fun(self, q, k, v, B, N, C):
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = (attn @ v).transpose(2, 3).reshape(B, C, N)
        return x

    def forward(self, x):
        # B, C, H, W
        B, _, _, _ = x.shape
        xa = self.pool(x)
        xa = xa.permute(0, 2, 3, 1).view(B, -1, self.dim)
        B, N, C = xa.shape
        qkv = self.qkv(xa).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        xa = self.att_fun(q, k, v, B, N, C)
        xa = xa.view(B, C, int(N ** 0.5), int(N ** 0.5))  # .permute(0, 3, 1, 2)

        xa = self.uppool(xa)
        return xa


class Mixer(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
                 attention_head=1, pool_size=2, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads

        self.low_dim = low_dim = attention_head * head_dim
        self.high_dim = high_dim = dim - low_dim

        self.high_mixer = HighMixer(high_dim)
        self.low_mixer = LowMixer(low_dim, num_heads=attention_head, qkv_bias=qkv_bias,
                                  attn_drop=attn_drop, pool_size=pool_size,)

        self.conv_fuse = nn.Conv2d(low_dim + high_dim * 2, low_dim + high_dim * 2, kernel_size=3,
                                   stride=1, padding=1, bias=False, groups=low_dim + high_dim * 2)
        self.proj = nn.Conv2d(low_dim + high_dim * 2, dim, kernel_size=1, stride=1, padding=0)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)

        hx = x[:, :self.high_dim, :, :].contiguous()
        hx = self.high_mixer(hx)

        lx = x[:, self.high_dim:, :, :].contiguous()
        lx = self.low_mixer(lx)

        x = torch.cat((hx, lx), dim=1)
        x = x + self.conv_fuse(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        x = x.permute(0, 2, 3, 1).contiguous()
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 attention_head=1, pool_size=2, attn=Mixer,
                 use_layer_scale=False, layer_scale_init_value=1e-5):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn(dim, num_heads=num_heads,
                         qkv_bias=qkv_bias, attn_drop=attn_drop,
                         attention_head=attention_head, pool_size=pool_size,)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.use_layer_scale = use_layer_scale
        if self.use_layer_scale:
            # print('use layer scale init value {}'.format(layer_scale_init_value))
            self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class InceptionTransformer(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dims=None, depths=None, num_heads=None, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed,
                 norm_layer=None, act_layer=None, weight_init='',
                 attention_heads=None,
                 use_layer_scale=False, layer_scale_init_value=1e-5,
                 checkpoint_path=None,
                 **kwargs):
        super().__init__()
        st2_idx = sum(depths[:1])
        st3_idx = sum(depths[:2])
        st4_idx = sum(depths[:3])
        depth = sum(depths)

        self.num_classes = num_classes
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.patch_embed = FirstPatchEmbed(in_chans=in_chans, embed_dim=embed_dims[0])
        self.num_patches1 = num_patches = img_size // 4
        self.pos_embed1 = nn.Parameter(torch.zeros(1, num_patches,
                                                   num_patches, embed_dims[0]))
        self.blocks1 = nn.Sequential(*[
            Block(
                dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                act_layer=act_layer, attention_head=attention_heads[i], pool_size=2,)
                # use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
                # )
            for i in range(0, st2_idx)])

        self.patch_embed2 = embed_layer(kernel_size=3, stride=2, padding=1, in_chans=embed_dims[0],
                                        embed_dim=embed_dims[1])
        self.num_patches2 = num_patches = num_patches // 2
        self.pos_embed2 = nn.Parameter(torch.zeros(1, num_patches, num_patches, embed_dims[1]))
        self.blocks2 = nn.Sequential(*[
            Block(
                dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                act_layer=act_layer, attention_head=attention_heads[i], pool_size=2,)
                # use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, channel_layer_scale=channel_layer_scale,
                # )
            for i in range(st2_idx, st3_idx)])

        self.patch_embed3 = embed_layer(kernel_size=3, stride=2, padding=1, in_chans=embed_dims[1],
                                        embed_dim=embed_dims[2])
        self.num_patches3 = num_patches = num_patches // 2
        self.pos_embed3 = nn.Parameter(torch.zeros(1, num_patches, num_patches, embed_dims[2]))
        self.blocks3 = nn.Sequential(*[
            Block(
                dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                act_layer=act_layer, attention_head=attention_heads[i], pool_size=1,
                use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
            )
            for i in range(st3_idx, st4_idx)])

        self.patch_embed4 = embed_layer(kernel_size=3, stride=2, padding=1, in_chans=embed_dims[2],
                                        embed_dim=embed_dims[3])
        self.num_patches4 = num_patches = num_patches // 2
        self.pos_embed4 = nn.Parameter(torch.zeros(1, num_patches, num_patches,
                                                   embed_dims[3]))
        self.blocks4 = nn.Sequential(*[
            Block(
                dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                act_layer=act_layer, attention_head=attention_heads[i], pool_size=1,
                use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
            )
            for i in range(st4_idx, depth)])

        self.norm = norm_layer(embed_dims[-1])

        # Classifier head(s)
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

        # set post block, for example, class attention layers
        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        trunc_normal_(self.pos_embed1, std=.02)
        trunc_normal_(self.pos_embed2, std=.02)
        trunc_normal_(self.pos_embed3, std=.02)
        trunc_normal_(self.pos_embed4, std=.02)
        self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def _get_pos_embed(self, pos_embed, num_patches_def, H, W):
        if H * W == num_patches_def * num_patches_def:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.permute(0, 3, 1, 2), size=(H, W), mode="bilinear").permute(0, 2, 3, 1)

    def forward_features(self, x):
        x = self.patch_embed(x)
        B, H, W, C = x.shape
        x = x + self._get_pos_embed(self.pos_embed1, self.num_patches1, H, W)
        x = self.blocks1(x)

        x = x.permute(0, 3, 1, 2)
        x = self.patch_embed2(x)
        B, H, W, C = x.shape
        x = x + self._get_pos_embed(self.pos_embed2,
                                    self.num_patches2, H, W)
        x = self.blocks2(x)

        x = x.permute(0, 3, 1, 2)
        x = self.patch_embed3(x)
        B, H, W, C = x.shape
        x = x + self._get_pos_embed(self.pos_embed3, self.num_patches3, H, W)
        x = self.blocks3(x)

        x = x.permute(0, 3, 1, 2)
        x = self.patch_embed4(x)
        B, H, W, C = x.shape
        x = x + self._get_pos_embed(self.pos_embed4, self.num_patches4, H, W)
        x = self.blocks4(x)

        x = x.flatten(1, 2)
        x = self.norm(x)
        return x.mean(1)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            trunc_normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
    elif isinstance(module, nn.Conv2d):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)


@register_model
def iformer_small(pretrained=False, **kwargs):
    """
    19.866M 4.849G 83.382
    """
    depths = [3, 3, 9, 3]
    embed_dims = [96, 192, 320, 384]
    num_heads = [3, 6, 10, 12]
    attention_heads = [1] * 3 + [3] * 3 + [7] * 4 + [9] * 5 + [11] * 3

    model = InceptionTransformer(img_size=224,
                                 depths=depths,
                                 embed_dims=embed_dims,
                                 num_heads=num_heads,
                                 attention_heads=attention_heads,
                                 use_layer_scale=True, layer_scale_init_value=1e-6,
                                 **kwargs)
    model.default_cfg = default_cfgs['iformer_small']
    if pretrained:
        url = model.default_cfg['url']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint)
    return model


@register_model
def iformer_small_384(pretrained=False, **kwargs):
    depths = [3, 3, 9, 3]
    embed_dims = [96, 192, 320, 384]
    num_heads = [3, 6, 10, 12]
    attention_heads = [1] * 3 + [3] * 3 + [7] * 4 + [9] * 5 + [11] * 3

    model = InceptionTransformer(img_size=384,
                                 depths=depths,
                                 embed_dims=embed_dims,
                                 num_heads=num_heads,
                                 attention_heads=attention_heads,
                                 use_layer_scale=True, layer_scale_init_value=1e-6,
                                 **kwargs)
    model.default_cfg = default_cfgs['iformer_small_384']
    if pretrained:
        url = model.default_cfg['url']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint)
    return model


@register_model
def iformer_base(pretrained=False, **kwargs):
    """
    47.866M 9.379G 84.598
    """
    depths = [4, 6, 14, 6]
    embed_dims = [96, 192, 384, 512]
    num_heads = [3, 6, 12, 16]
    attention_heads = [1] * 4 + [3] * 6 + [8] * 7 + [10] * 7 + [15] * 6

    model = InceptionTransformer(img_size=224,
                                 depths=depths,
                                 embed_dims=embed_dims,
                                 num_heads=num_heads,
                                 attention_heads=attention_heads,
                                 use_layer_scale=True, layer_scale_init_value=1e-6,
                                 **kwargs)
    model.default_cfg = default_cfgs['iformer_base']
    if pretrained:
        url = model.default_cfg['url']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint)
    return model


@register_model
def iformer_base_384(pretrained=False, **kwargs):
    depths = [4, 6, 14, 6]
    embed_dims = [96, 192, 384, 512]
    num_heads = [3, 6, 12, 16]
    attention_heads = [1] * 4 + [3] * 6 + [8] * 7 + [10] * 7 + [15] * 6

    model = InceptionTransformer(img_size=384,
                                 depths=depths,
                                 embed_dims=embed_dims,
                                 num_heads=num_heads,
                                 attention_heads=attention_heads,
                                 use_layer_scale=True, layer_scale_init_value=1e-6,
                                 **kwargs)
    model.default_cfg = default_cfgs['iformer_base_384']
    if pretrained:
        url = model.default_cfg['url']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint)
    return model


@register_model
def iformer_large(pretrained=False, **kwargs):
    """
    86.637M 14.048G 84.752
    """
    depths = [4, 6, 18, 8]
    embed_dims = [96, 192, 448, 640]
    num_heads = [3, 6, 14, 20]
    attention_heads = [1] * 4 + [3] * 6 + [10] * 9 + [12] * 9 + [19] * 8

    model = InceptionTransformer(img_size=224,
                                 depths=depths,
                                 embed_dims=embed_dims,
                                 num_heads=num_heads,
                                 attention_heads=attention_heads,
                                 use_layer_scale=True, layer_scale_init_value=1e-6,
                                 **kwargs)
    model.default_cfg = default_cfgs['iformer_large']
    if pretrained:
        url = model.default_cfg['url']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint)
    return model


@register_model
def iformer_large_384(pretrained=False, **kwargs):
    depths = [4, 6, 18, 8]
    embed_dims = [96, 192, 448, 640]
    num_heads = [3, 6, 14, 20]
    attention_heads = [1] * 4 + [3] * 6 + [10] * 9 + [12] * 9 + [19] * 8

    model = InceptionTransformer(img_size=384,
                                 depths=depths,
                                 embed_dims=embed_dims,
                                 num_heads=num_heads,
                                 attention_heads=attention_heads,
                                 use_layer_scale=True, layer_scale_init_value=1e-6,
                                 **kwargs)
    model.default_cfg = default_cfgs['iformer_large_384']
    if pretrained:
        url = model.default_cfg['url']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint)
    return model


================================================
FILE: setup.cfg
================================================
[dist_conda]
conda_name_differences = 'torch:pytorch'
channels = pytorch
noarch = True


================================================
FILE: train.py
================================================
#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces
ImageNet training results with some of the latest networks and training techniques.
It favours canonical PyTorch and standard Python style over trying to be able to 'do it all.'
That said, it offers quite a few speed and training result improvements over the usual PyTorch
example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
# from types import _KT
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
    convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import *
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table

import models

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this
argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# ============self parameters==========================

# Dataset parameters
parser.add_argument('data_dir', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')

# Model parameters
parser.add_argument('--model', default='deit_small_patch16_224', type=str, metavar='MODEL',
                    help='Name of model to train (default: "deit_small_patch16_224")')
parser.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                    help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
                    help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
                    help='Image patch size (default: None => model default)')
parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N',
                    help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float, metavar='N',
                    help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='validation batch size override (default: None)')

# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=2e-5,
                    help='weight decay (default: 2e-5)')
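The schedule flags defined below (`--sched cosine`, `--warmup-lr`, `--min-lr`, `--warmup-epochs`, `--epochs`) are consumed by timm's `create_scheduler`. A dependency-free sketch of the warmup-then-cosine curve they select, using this script's own argparse defaults; it is simplified, since timm's `CosineLRScheduler` additionally supports cycle multipliers, LR noise, and k-decay:

```python
import math

# Warmup + cosine decay, with this script's defaults:
# --lr 0.05, --warmup-lr 1e-4, --min-lr 1e-6, --warmup-epochs 3, --epochs 300.
def lr_at(epoch, base_lr=0.05, warmup_lr=1e-4, min_lr=1e-6,
          warmup_epochs=3, total_epochs=300):
    if epoch < warmup_epochs:
        # linear warmup from warmup_lr up to base_lr
        t = epoch / warmup_epochs
        return warmup_lr + t * (base_lr - warmup_lr)
    # cosine decay from base_lr down to min_lr over the remaining epochs
    t = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

assert abs(lr_at(0) - 1e-4) < 1e-9    # starts at warmup_lr
assert abs(lr_at(3) - 0.05) < 1e-9    # reaches base_lr when warmup ends
assert abs(lr_at(300) - 1e-6) < 1e-9  # decays to min_lr at the final epoch
```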
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')

# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                    help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                    help='learning rate (default: 0.05)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                    help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                    help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
                    help='amount to decay each learning rate cycle (default: 0.5)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit, cycles enabled if > 1')
parser.add_argument('--lr-k-decay', type=float, default=1.0,
                    help='learning rate k-decay for cosine/poly (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                    help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                    help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
parser.add_argument('--epochs', type=int, default=300, metavar='N',
                    help='number of epochs to train (default: 300)')
parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
                    help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                    help='LR decay rate (default: 0.1)')

# Augmentation & regularization parameters
parser.add_argument('--no-aug', action='store_true', default=False,
                    help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                    help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                    help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
                    help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
                    help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                    help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                    help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--aug-repeats', type=int, default=0,
                    help='Number of augmentation repetitions (distributed training only) (default: 0)')
parser.add_argument('--aug-splits', type=int, default=0,
                    help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd-loss', action='store_true', default=False,
                    help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--bce-loss', action='store_true', default=False,
                    help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--bce-target-thresh', type=float, default=None,
                    help='Threshold for binarizing softened BCE targets (default: None, disabled)')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                    help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
                    help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
                    help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
                    help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
                    help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
                    help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
                    help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                    help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
                    help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                    help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
                    help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
                    help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
                    help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='reduce',
                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
                    help='Enable separate BN layers per augmentation split.')

# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
                    help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                    help='decay factor for model weights moving average (default: 0.9998)')

# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--worker-seeding', type=str, default='all',
                    help='worker seed mode (default: all)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                    help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('--checkpoint-hist', type=int, default=1, metavar='N',
                    help='number of checkpoints to keep (default: 1)')
parser.add_argument('-j', '--workers', type=int, default=10, metavar='N',
                    help='how many training processes to use (default: 10)')
parser.add_argument('--save-images', action='store_true', default=False,
                    help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=True,
                    help='disable fast prefetcher')
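`_parse_args` below implements a two-stage parse: a minimal parser first peels off `--config`, the YAML file's key/values are installed as new defaults on the main parser via `set_defaults`, and any flags given explicitly on the command line still take precedence. A self-contained sketch of that pattern (the `yaml.safe_load` step is replaced by a plain dict so the example has no dependencies):

```python
import argparse

def parse(argv, config_values):
    # stage 1: a bare parser extracts --config and leaves everything else untouched
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument('-c', '--config', default='')
    _, remaining = config_parser.parse_known_args(argv)

    # stage 2: config values become defaults; explicit CLI flags still win
    main_parser = argparse.ArgumentParser()
    main_parser.add_argument('--lr', type=float, default=0.05)
    main_parser.add_argument('--epochs', type=int, default=300)
    main_parser.set_defaults(**config_values)  # stand-in for yaml.safe_load(f)
    return main_parser.parse_args(remaining)

args = parse(['--lr', '0.1'], {'epochs': 100})
assert args.lr == 0.1      # explicit CLI flag overrides everything
assert args.epochs == 100  # config file overrides the parser default
```

The precedence order this yields (CLI flag > config file > hard-coded default) is why the same `args.yaml` files checked in under checkpoint/ can reproduce a run while still allowing one-off overrides.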
parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='path to output folder (default: none, current dir)')
parser.add_argument('--experiment', default='', type=str, metavar='NAME',
                    help='name of train experiment, name of sub-folder for output')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')
parser.add_argument('--log-wandb', action='store_true', default=False,
                    help='log training and validation metrics to wandb')


def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text


def main():
    setup_default_logging()
    args, args_text = _parse_args()

    if args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning("You've requested to log metrics to wandb but package not found. "
                            "Metrics not being logged to wandb, try `pip install wandb`")

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
                        "Install NVIDIA Apex or upgrade to PyTorch 1.6")

    random_seed(args.seed, args.rank)

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
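Mixup and CutMix (the `--mixup`, `--cutmix`, and related flags above) are handled later in this script by timm's `Mixup`/`FastCollateMixup`; the core mixup rule simply blends two samples and their targets with a single weight `lam`. A minimal illustrative sketch (in timm, `lam` is drawn from a Beta(alpha, alpha) distribution per batch; the hand-picked `lam` and toy vectors here are purely for demonstration):

```python
def mixup_pair(x1, x2, y1, y2, lam):
    """Blend two samples and their one-hot targets: lam * a + (1 - lam) * b."""
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y

# two toy "images" with one-hot labels for a 2-class problem
x, y = mixup_pair([1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0], lam=0.8)
assert all(abs(a - b) < 1e-9 for a, b in zip(x, [0.8, 0.2]))
assert all(abs(a - b) < 1e-9 for a, b in zip(y, [0.8, 0.2]))
```

Because the mixed targets are soft, the loss-function setup below swaps plain cross-entropy for `SoftTargetCrossEntropy` (or BCE) whenever mixup/cutmix is active.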
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    model.eval()
    flops = FlopCountAnalysis(model, torch.rand(1, 3, 224, 224))

    if args.local_rank == 0:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
        print(flop_count_table(flops))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        if os.path.exists(args.resume):
            resume_epoch = resume_checkpoint(
                model, args.resume,
                optimizer=None if args.no_resume_opt else optimizer,
                loss_scaler=None if args.no_resume_opt else loss_scaler,
                log_info=args.local_rank == 0,
            )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        repeats=args.epoch_repeats)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)
    if collate_fn is not None:
        print('collate_fn is not none')
    if mixup_fn is not None:
        print('mixup_fn is not none')

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size or args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if args.rank == 0:
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing,
            max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    try:
        entropy_thr = 0
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics, all_entropy = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema,
                mixup_fn=mixup_fn, entropy_thr=entropy_thr)
            # all_entropy = torch.stack(all_entropy, dim=0)
            # entropy_thr = all_entropy.mean()
            # entropy_thr = entropy_thr * 2.0

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args,
                    amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if output_dir is not None:
                update_summary(
                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None, entropy_thr=None):

    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    losses_m = AverageMeter()

    model.train()
    all_entropy = []

    end = time.time()
    last_idx = len(loader) - 1
    num_updates = epoch * len(loader)
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        data_time_m.update(time.time() - end)
        if not args.prefetcher:
            input, target = input.cuda(), target.cuda()
        # fall back to the unmixed batch when mixup/cutmix is disabled, otherwise
        # input_mix/target_mix would be undefined in the forward pass below
        input_mix, target_mix = input, target
        if mixup_fn is not None:
            input_mix, target_mix = mixup_fn(input, target)
        if args.channels_last:
            input_mix = input_mix.contiguous(memory_format=torch.channels_last)

        with amp_autocast():
            output = model(input_mix)
            loss = loss_fn(output, target_mix)

        if not args.distributed:
            losses_m.update(loss.item(), input.size(0))

        optimizer.zero_grad()
        if loss_scaler is not None:
            loss_scaler(
                loss, optimizer,
                clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                create_graph=second_order)
        else:
            loss.backward(create_graph=second_order)
            if args.clip_grad is not None:
                dispatch_clip_grad(
                    model_parameters(model, exclude_head='agc' in args.clip_mode),
                    value=args.clip_grad, mode=args.clip_mode)
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)

        torch.cuda.synchronize()
        num_updates += 1
        batch_time_m.update(time.time() - end)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item(), input.size(0))

            if args.local_rank == 0:
                _logger.info(
                    'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'LR: {lr:.3e} '
                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                        epoch,
                        batch_idx, len(loader),
                        100. * batch_idx / last_idx,
                        loss=losses_m,
                        batch_time=batch_time_m,
                        rate=input.size(0) * args.world_size / batch_time_m.val,
                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
                        lr=lr,
                        data_time=data_time_m))

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and (
                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=batch_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        end = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)]), all_entropy


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    top1_m = AverageMeter()
    top5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
            # print(output.size())
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = reduce_tensor(loss.data, args.world_size)
                acc1 = reduce_tensor(acc1, args.world_size)
                acc5 = reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m,
                        loss=losses_m,
                        top1=top1_m,
                        top5=top5_m))

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
    return metrics


if __name__ == '__main__':
    main()


================================================
FILE: validate.py
================================================
#!/usr/bin/env python3
""" ImageNet Validation Script

This is intended to be a lean and easily modifiable ImageNet validation script for evaluating
pretrained models or training checkpoints against ImageNet or similarly organized image datasets.
It prioritizes canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
Hacked together by Ross Wightman (https://github.com/rwightman) """ import argparse import os import csv import glob import time import logging import torch import torch.nn as nn import torch.nn.parallel from collections import OrderedDict from contextlib import suppress from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy has_apex = False try: from apex import amp has_apex = True except ImportError: pass has_native_amp = False try: if getattr(torch.cuda.amp, 'autocast') is not None: has_native_amp = True except AttributeError: pass torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') # ============self parameters========================== parser.add_argument('--K', default=1024, type=int, help="memory size 1536、2560、7680 1024") parser.add_argument('--embed_dim', default=384, type=int, help="mem_depth, 3, 6") # ============self parameters========================== parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--dataset', '-d', metavar='NAME', default='', help='dataset type (default: ImageFolder/ImageTar if empty)') parser.add_argument('--split', metavar='NAME', default='validation', help='dataset split (default: validation)') parser.add_argument('--dataset-download', action='store_true', default=False, help='Allow download of dataset for torch/ and tfds/ datasets that support it.') parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', help='model architecture (default: dpn92)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', 
help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop pct') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', help='Override std deviation of of dataset') parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('--num-classes', type=int, default=None, help='Number classes in dataset') parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') parser.add_argument('--log-freq', default=10, type=int, metavar='N', help='batch logging frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') parser.add_argument('--test-pool', dest='test_pool', action='store_true', help='enable test time pool') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--amp', action='store_true', default=False, help='Use AMP mixed precision. 
Defaults to Apex, fallback to native Torch AMP.') parser.add_argument('--apex-amp', action='store_true', default=False, help='Use NVIDIA Apex AMP mixed precision') parser.add_argument('--native-amp', action='store_true', default=False, help='Use Native Torch AMP mixed precision') parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true', help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', help='Real labels JSON file for imagenet evaluation') parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', help='Valid label indices txt file for validation of partial label space') def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher amp_autocast = suppress # do nothing if args.amp: if has_native_amp: args.native_amp = True elif has_apex: args.apex_amp = True else: _logger.warning("Neither APEX or Native Torch AMP is available.") assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." if args.native_amp: amp_autocast = torch.cuda.amp.autocast _logger.info('Validating in mixed precision with native PyTorch AMP.') elif args.apex_amp: _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') else: _logger.info('Validating in float32. 
AMP not enabled.') if args.legacy_jit: set_jit_legacy() # ======================== args.mem_idx = [10] # args.mem_idx = list(range(1,13)) # mem_idx = args.mem_idx # student_mem = vit_mem(args.embed_dim, mem_idx=args.mem_idx, K=args.K, top_n=args.knn, n_center=args.n_center) # student_mem = student_mem.cuda() # ======================== # create model model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, in_chans=3, global_pool=args.gp, scriptable=args.torchscript, mem_index=args.mem_idx, K=args.K) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes if args.checkpoint: load_checkpoint(model, args.checkpoint, args.use_ema) param_count = sum([m.numel() for m in model.parameters()]) _logger.info('Model %s created, param count: %d' % (args.model, param_count)) data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) test_time_pool = False if args.test_pool: model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) if args.torchscript: torch.jit.optimized_execution(True) model = torch.jit.script(model) model = model.cuda() if args.apex_amp: model = amp.initialize(model, opt_level='O1') if args.channels_last: model = model.to(memory_format=torch.channels_last) if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) criterion = nn.CrossEntropyLoss().cuda() dataset = create_dataset( root=args.data, name=args.dataset, split=args.split, download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map) if args.valid_labels: with open(args.valid_labels, 'r') as f: valid_labels = {int(line.rstrip()) for line in f} valid_labels = [i in valid_labels for i in range(args.num_classes)] else: valid_labels = None if args.real_labels: real_labels = RealLabelsImagenet(dataset.filenames(basename=True), 
real_json=args.real_labels) else: real_labels = None crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] loader = create_loader( dataset, input_size=data_config['input_size'], batch_size=args.batch_size, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, crop_pct=crop_pct, pin_memory=args.pin_mem, tf_preprocessing=args.tf_preprocessing) batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.eval() with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda() if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) model(input) end = time.time() for batch_idx, (input, target) in enumerate(loader): if args.no_prefetcher: target = target.cuda() input = input.cuda() if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) # compute output with amp_autocast(): output = model(input) if valid_labels is not None: output = output[:, valid_labels] loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output) # measure accuracy and record loss acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(acc1.item(), input.size(0)) top5.update(acc5.item(), input.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if batch_idx % args.log_freq == 0: _logger.info( 'Test: [{0:>4d}/{1}] ' 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( batch_idx, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, loss=losses, top1=top1, top5=top5)) 
    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results


def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k'])
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

        if not model_cfgs and os.path.isfile(args.model):
            with open(args.model) as f:
                model_names = [line.rstrip() for line in f]
            model_cfgs = [(n, None) for n in model_names if n]

    if len(model_cfgs):
        results_file = args.results_file or './results-all.csv'
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            start_batch_size = args.batch_size
            for m, c in model_cfgs:
                batch_size = start_batch_size
                args.model = m
                args.checkpoint = c
                result = OrderedDict(model=args.model)
                r = {}
                while not r and batch_size >= args.num_gpu:
                    torch.cuda.empty_cache()
                    try:
                        args.batch_size = batch_size
                        print('Validating with batch size: %d' % args.batch_size)
                        r = validate(args)
                    except RuntimeError as e:
                        if batch_size <= args.num_gpu:
                            print("Validation failed with no ability to reduce batch size. Exiting.")
                            raise e
                        batch_size = max(batch_size // 2, args.num_gpu)
                        print("Validation failed, reducing batch size by 50%")
                result.update(r)
                if args.checkpoint:
                    result['checkpoint'] = args.checkpoint
                results.append(result)
        except KeyboardInterrupt:
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        validate(args)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
        cf.flush()


if __name__ == '__main__':
    main()
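The bulk-validation loop in `main()` above retries each model with a halved batch size whenever `validate(args)` raises a `RuntimeError` (typically CUDA out-of-memory), and re-raises once the batch size can no longer shrink. A minimal, CPU-only sketch of that back-off pattern — `validate_with_backoff` and `fake_validate` are hypothetical stand-ins, not part of the repository:

```python
def validate_with_backoff(run_once, batch_size, min_batch_size=1):
    """Retry run_once(batch_size), halving the batch size on RuntimeError.

    Mirrors the retry loop in validate.py's main(): each failure halves the
    batch size (never below min_batch_size); a failure at the floor re-raises.
    Returns the first successful result and the batch size that worked.
    """
    result = {}
    while not result and batch_size >= min_batch_size:
        try:
            result = run_once(batch_size)
        except RuntimeError:
            if batch_size <= min_batch_size:
                raise  # no room left to back off, surface the error
            batch_size = max(batch_size // 2, min_batch_size)
    return result, batch_size


def fake_validate(bs):
    """Hypothetical runner: simulates CUDA OOM for any batch size above 100."""
    if bs > 100:
        raise RuntimeError('CUDA out of memory (simulated)')
    return {'top1': 81.2, 'batch_size': bs}


if __name__ == '__main__':
    metrics, used_bs = validate_with_backoff(fake_validate, 256)
    print(used_bs)  # prints 64: backed off 256 -> 128 -> 64
```

The key design point, shared with `main()`, is that state shrinks geometrically, so even a badly oversized starting batch costs only O(log n) failed attempts before a fit is found.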