Repository: sail-sg/iFormer
Branch: main
Commit: 725d8e7f455b
Files: 21
Total size: 228.8 KB
Directory structure:
gitextract_p3kgg1_k/
├── LICENSE
├── MANIFEST.in
├── README.md
├── checkpoint/
│ ├── iformer_base/
│ │ ├── args.yaml
│ │ └── summary.csv
│ ├── iformer_large/
│ │ ├── args.yaml
│ │ └── summary.csv
│ └── iformer_small/
│ ├── args.yaml
│ └── summary.csv
├── checkpoint_384/
│ ├── iformer_base_384/
│ │ ├── args.yaml
│ │ └── summary.csv
│ ├── iformer_large_384/
│ │ ├── args.yaml
│ │ └── summary.csv
│ └── iformer_small_384/
│ ├── args.yaml
│ └── summary.csv
├── fine-tune.py
├── models/
│ ├── __init__.py
│ └── inception_transformer.py
├── setup.cfg
├── train.py
└── validate.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019 Ross Wightman
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MANIFEST.in
================================================
include timm/models/pruned/*.txt
================================================
FILE: README.md
================================================
# iFormer: [Inception Transformer](http://arxiv.org/abs/2205.12956) (NeurIPS 2022 Oral)
This is a PyTorch implementation of iFormer proposed by our paper "[Inception Transformer](http://arxiv.org/abs/2205.12956)".
## Image Classification
### Requirements
torch>=1.7.0; torchvision>=0.8.1; timm==0.5.4; fvcore; [apex-amp](https://github.com/NVIDIA/apex) (optional, required only for fp16 training).
Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4).
```
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
```
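Before launching training, it can help to sanity-check that the dataset root matches this layout. The helper below is a hypothetical sketch (not part of this repository) that counts the class folders under `train/` and `val/`; for ImageNet-1K both counts should be 1000.

```python
from pathlib import Path

def check_imagenet_layout(root):
    """Count class folders under train/ and val/ of an ImageNet root.

    Hypothetical helper, not part of this repository: both splits should
    contain one directory per WordNet synset (1000 for ImageNet-1K).
    """
    root = Path(root)
    n_train = sum(1 for d in (root / "train").iterdir() if d.is_dir())
    n_val = sum(1 for d in (root / "val").iterdir() if d.is_dir())
    return n_train, n_val
```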
### Main results on ImageNet-1K
| Model | #params | FLOPs | Image resolution | acc@1 | Download |
| :--- | :---: | :---: | :---: | :---: | :---: |
| iFormer-S | 20M | 4.8G | 224 | 83.4 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_small/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_small/summary.csv) |
| iFormer-B | 48M | 9.4G | 224 | 84.6 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_base/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_base/summary.csv) |
| iFormer-L | 87M | 14.0G | 224 | 84.8 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_large/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint/iformer_large/summary.csv) |
### Fine-tuning results with larger resolution (384x384) on ImageNet-1K
| Model | #params | FLOPs | Image resolution | acc@1 | Download |
| :--- | :---: | :---: | :---: | :---: | :---: |
| iFormer-S | 20M | 16.1G | 384 | 84.6 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small_384.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_small_384/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_small_384/summary.csv) |
| iFormer-B | 48M | 30.5G | 384 | 85.7 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base_384.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_base_384/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_base_384/summary.csv) |
| iFormer-L | 87M | 45.3G | 384 | 85.8 | [model](https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large_384.pth)/[config](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_large_384/args.yaml)/[log](https://github.com/sail-sg/iFormer/blob/main/checkpoint_384/iformer_large_384/summary.csv) |
### Training
Train `iformer_small` at 224x224 resolution:
```bash
python -m torch.distributed.launch --nproc_per_node=8 train.py /dataset/imagenet \
--model iformer_small -b 128 --epochs 300 --img-size 224 --drop-path 0.2 --lr 1e-3 \
--weight-decay 0.05 --aa rand-m9-mstd0.5-inc1 --warmup-lr 1e-6 --warmup-epochs 5 \
--output checkpoint --min-lr 1e-6 --experiment iformer_small
```
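This command runs 8 GPUs with a per-GPU batch of 128, i.e. a global batch of 1024, at `--lr 1e-3`. If you train with a different number of GPUs or per-GPU batch, a common heuristic (linear learning-rate scaling; a convention, not something prescribed by this repository) is to scale the learning rate with the global batch:

```python
def scaled_lr(n_gpus, per_gpu_batch, base_lr=1e-3, base_global_batch=1024):
    # Linear LR scaling heuristic: scale lr in proportion to the global
    # batch size relative to the reference run (8 GPUs x 128 = 1024, lr 1e-3).
    # This is a common convention, not prescribed by the iFormer repo.
    return base_lr * (n_gpus * per_gpu_batch) / base_global_batch

# e.g. 4 GPUs x 128 per GPU = global batch 512 -> lr 5e-4 under this heuristic
```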
Fine-tune at 384x384 resolution from the checkpoint pretrained at 224:
```bash
python -m torch.distributed.launch --nproc_per_node=8 fine-tune.py /dataset/imagenet \
--model iformer_small_384 -b 64 --lr 1e-5 --min-lr 1e-6 --warmup-lr 2e-8 --warmup-epochs 0 \
--epochs 20 --img-size 384 --drop-path 0.3 --weight-decay 1e-8 --mixup 0.1 --cutmix 0.1 \
--cooldown-epochs 10 --aa rand-m9-mstd0.5-inc1 --clip-grad 1.0 --output checkpoint_fine \
--initial-checkpoint checkpoint/iformer_small/model_best.pth.tar \
--experiment iformer_small_384
```
### Validation
```bash
python validate.py /dataset/imagenet --model iformer_small --checkpoint checkpoint/iformer_small/model_best.pth.tar
```
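Per-epoch metrics are appended to a `summary.csv` in the experiment folder (see `checkpoint/*/summary.csv` in this repo). A small stdlib sketch to pull the best top-1 epoch out of such a log; the column names below match the shipped files:

```python
import csv

def best_top1(summary_csv):
    """Return (epoch, eval_top1) for the best row of a summary.csv with
    columns epoch,train_loss,eval_loss,eval_top1,eval_top5."""
    with open(summary_csv, newline="") as f:
        rows = list(csv.DictReader(f))
    best = max(rows, key=lambda r: float(r["eval_top1"]))
    return int(best["epoch"]), float(best["eval_top1"])
```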
## Object Detection and Instance Segmentation
All models are based on Mask R-CNN and trained with the 1x training schedule.
| Backbone | #Param. | FLOPs | box mAP | mask mAP |
|:---------:|:-------:|:-----:|:-------:|:--------:|
| iFormer-S | 40M | 263G | 46.2 | 41.9 |
| iFormer-B | 67M | 351G | 48.3 | 43.3 |
## Semantic Segmentation
| Backbone | Method | #Param. | FLOPs | mIoU |
|:---------:|---------|:-------:|:-----:|:----:|
| iFormer-S | FPN | 24M | 181G | 48.6 |
| iFormer-S | Upernet | 49M | 938G | 48.4 |
## Bibtex
```
@inproceedings{
si2022inception,
title={Inception Transformer},
author={Chenyang Si and Weihao Yu and Pan Zhou and Yichen Zhou and Xinchao Wang and Shuicheng YAN},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
```
## Acknowledgment
Our implementation is mainly based on the following codebases. We sincerely thank the authors for their excellent work.
[pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [mmdetection](https://github.com/open-mmlab/mmdetection), [mmsegmentation](https://github.com/open-mmlab/mmsegmentation).
In addition, Weihao Yu would like to thank the TPU Research Cloud (TRC) program for supporting part of the computational resources.
================================================
FILE: checkpoint/iformer_base/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: false
aug_repeats: 3
aug_splits: 0
batch_size: 64
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 1.0
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 30.0
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.4
embed_dim: 384
epoch_repeats: 0.0
epochs: 300
eval_metric: top1
experiment: iformer_base
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.001
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-05
mixup: 0.8
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_base
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: 1.0e-08
output: checkpoint
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 5
warmup_lr: 1.0e-06
weight_decay: 0.05
worker_seeding: all
workers: 10
================================================
FILE: checkpoint/iformer_base/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,6.908455812014067,6.860800276947021,0.3619999967956543,1.5300000119018555
1,6.665643343558679,5.8470182475280765,3.750000009765625,12.358000017089843
2,6.3399158624502325,4.92742544342041,11.036000020751953,27.733999997558595
3,6.007291243626521,4.11241763496399,20.144000025634767,42.53600006103515
4,5.740854611763587,3.553922503890991,27.132000051269532,53.624000043945315
5,5.421774644118089,3.008195775489807,37.086000013427736,64.2180000024414
6,5.238544537470891,2.725795027923584,42.46799998046875,69.79000006591797
7,5.00539570588332,2.439262678527832,48.0880001171875,74.49400002685547
8,4.882717260947595,2.2806560862350462,51.30599999267578,77.654000078125
9,4.747562701885517,2.09722495513916,54.68799995361328,80.28000007080078
10,4.535913467407227,1.9907448290634155,57.23400012695313,82.21600009033203
11,4.489044281152578,1.9909129524230957,58.418000046386716,83.2179999609375
12,4.439941516289344,1.8225921615219116,60.41000010986328,84.44999995117188
13,4.332764918987568,1.8817005332565309,61.10000019775391,85.06000002929687
14,4.261882268465483,1.6659569248199464,63.4940000024414,86.6240000024414
15,4.186580703808711,1.6334883233642579,63.90200006591797,86.71199992675781
16,4.21221111370967,1.6058608751106263,65.52199997314453,87.83800005126953
17,4.1268850198158855,1.525306345729828,65.95999991699219,88.33599999511719
18,4.047768574494582,1.518114146270752,66.62600010009766,88.70400015136718
19,4.040863623985877,1.525582196998596,67.52799990234375,89.14200002441406
20,4.001404560529268,1.484820670261383,67.73800007080078,89.16000020263672
21,3.9741526475319495,1.4245139252471923,68.33000017822266,89.64000015136719
22,3.902811747330886,1.4404746714973449,68.7499999951172,89.90000004638672
23,3.888140925994286,1.452209776649475,68.85600004882812,89.92599997070313
24,3.8569280092532816,1.355765485534668,69.79400004638671,90.53599996582031
25,3.874772942983187,1.322539180431366,69.96800012451172,90.5139998876953
26,3.824866845057561,1.3584467635345459,70.43400001953125,90.71600004638672
27,3.815111151108375,1.3259017392158507,70.89200012451172,91.10200007080078
28,3.7768686551314135,1.2894453044128418,71.04800004882813,91.17800004394532
29,3.7479228789989767,1.3424655978775024,71.19000008789062,91.07599999023438
30,3.7662950295668383,1.2917246441841126,71.43599998779297,91.60400004150391
31,3.754122550670917,1.3602314675521852,71.98599995605468,91.6720001196289
32,3.7166288174115696,1.2361285284996033,72.2459999584961,91.74399998779298
33,3.7051904568305383,1.2548866311454774,72.2139999609375,91.86599998779297
34,3.6960658752001248,1.2564913172149659,71.99000001708984,91.8040001220703
35,3.709755466534541,1.2993868688964845,72.33400013916015,91.94400011962891
36,3.695227127808791,1.2078956381416321,72.6780001196289,92.0100000390625
37,3.641167613176199,1.237737746334076,72.6480000415039,92.0580001147461
38,3.6478970601008487,1.1785106216049195,73.3219999633789,92.4960000415039
39,3.640158689939059,1.202273616809845,73.29000001708984,92.45399999023438
40,3.7028306814340444,1.2195009224510194,73.21800003662109,92.23600014404298
41,3.6155638694763184,1.2656590075683594,73.55800003417968,92.4960000390625
42,3.6116727957358727,1.2079661503982544,73.4180000415039,92.51400006591797
43,3.6359497217031627,1.1636047567939758,73.47799993408204,92.46399996337891
44,3.644525775542626,1.2462298067855835,73.5419999609375,92.73400001464844
45,3.5377747187247643,1.2198820894241333,73.96600000976562,92.71400004150391
46,3.577638406019944,1.1579090527915954,74.22400000488281,92.93800008789063
47,3.578922601846548,1.1946468775939942,74.06000000488281,92.89200006591797
48,3.6083018688055186,1.260998232460022,74.14000005859376,92.86400019775391
49,3.5781111992322483,1.1498024686813355,74.71200008789063,93.00600000976563
50,3.556312166727506,1.2316087501144408,74.506,93.2479999584961
51,3.5926481760465183,1.145808798084259,74.93599995117188,93.1839999609375
52,3.5503257513046265,1.2279409123039247,74.85000013183594,92.91600014160156
53,3.548908096093398,1.1786081212425232,75.20799995361328,93.3060000390625
54,3.5557837944764357,1.2336573403549194,74.51599998291016,93.11999993896484
55,3.4652767731593204,1.1627975540542603,75.40400003417969,93.35799999267579
56,3.479139502231891,1.138182176322937,75.28800010986328,93.4080001171875
57,3.5164922384115367,1.1406619202041626,75.22600006103515,93.27399998779296
58,3.501664170852074,1.1390633061408997,75.40600000732422,93.3199999633789
59,3.491689085960388,1.1517196659469604,75.43400013427734,93.35000009521484
60,3.5151750307816725,1.1642712952232361,75.45400010986329,93.34399998535156
61,3.4816090877239523,1.1658449925994874,75.48800003173828,93.6820001171875
62,3.4707261690726647,1.1238778091812134,75.67400008300781,93.52800009033203
63,3.4746356377234826,1.147294167804718,75.82200003417968,93.59600001220703
64,3.459448658503019,1.174054825153351,75.43000003417968,93.5079999609375
65,3.4489993590575,1.1662658145713807,75.70599995605468,93.57999990966798
66,3.485427141189575,1.1435979705429078,75.92599992675781,93.64600016601563
67,3.4828745401822605,1.134217993068695,76.18200002929687,93.7379999560547
68,3.45894560447106,1.156855950126648,75.72199998046875,93.55200006591797
69,3.44607728261214,1.0945456811714172,76.3340001611328,93.77799993408203
70,3.477376791147085,1.1079738793563843,76.24000010986327,93.91200016601563
71,3.5141904812592726,1.0928245606613158,76.30200008789062,93.69600013916016
72,3.410055022973281,1.0860164308166504,76.26800003173828,93.90199995849609
73,3.420612326035133,1.116220949783325,76.65600010498046,93.8940001171875
74,3.4254535344930797,1.1075060968399049,76.16600008300782,93.7419999584961
75,3.4103054358409,1.1158017666244506,76.32199995361329,93.9019998803711
76,3.476789043499873,1.0348516142463684,76.68800013671876,94.00600001220702
77,3.4337857228059034,1.0966742949867248,76.42800005371093,93.84000006103516
78,3.4188507887033315,1.0247036343955993,76.71600001464844,94.09599998779296
79,3.423209318747887,1.0942145678901671,76.81400018554687,94.01599993408203
80,3.3477271520174465,1.0325566992378234,76.62600013427735,93.98600003662109
81,3.4215016089952908,1.046407821083069,76.63400003662109,94.1180000366211
82,3.430701503386864,1.0959824050521851,77.01999987304687,94.0459999609375
83,3.4161557325950036,1.0931310430908203,76.86399993164062,93.97000000732422
84,3.3875768918257494,1.108051809272766,76.73400020996094,94.07200014160156
85,3.4196390096957865,1.14773931640625,76.86000002929687,94.1320000366211
86,3.437829402776865,1.0369087854003907,77.00200010498047,94.10800000976562
87,3.385591616997352,1.0996613207626342,77.08599995361328,94.2460000366211
88,3.4244728546876173,1.062659384174347,77.17800010742188,94.22800000976562
89,3.388879519242507,1.0323741394615173,77.2419999243164,94.26999990966797
90,3.378705382347107,1.1156270475387573,77.27599994628906,94.21800003662109
91,3.33666560283074,1.0247049710083007,77.25200018066407,94.21600001220703
92,3.386662767483638,1.0646427408790589,77.43999997802734,94.20000000976563
93,3.341835297071017,1.063674129047394,77.41000015625,94.33999998291016
94,3.3956225651961107,1.0314554983901978,77.34999997558593,94.32600013916016
95,3.338684366299556,1.1056486014938354,77.28800005371093,94.18600016357422
96,3.328054749048673,1.0338896635437012,77.83000005371093,94.4480000366211
97,3.3632744550704956,1.0532264839172363,77.81800002929687,94.45999998291016
98,3.3554288148880005,1.0673184670639038,77.83000002197265,94.48799992919922
99,3.3159962709133444,1.1167875634765625,77.59199997070313,94.43600013916016
100,3.3194725880256066,1.036779229888916,77.53200012939453,94.40800008789063
101,3.3196748128304114,0.9950653231811524,77.88199995605468,94.39600003662109
102,3.294141851938688,1.0245422155380248,78.03399997070312,94.53000011230469
103,3.3244126851742086,1.0557471356391908,77.7660000024414,94.47399998535157
104,3.3125311044546275,1.0209960681152344,78.218,94.63000006103516
105,3.3134930042120128,1.034874595451355,77.87000005371094,94.63400013916015
106,3.3416621409929714,1.022444501953125,77.88000011230469,94.48000003662109
107,3.2982225234691915,1.0652961749267578,77.80800005371094,94.52599993164063
108,3.2699141318981466,0.9846357432556152,78.00199992675782,94.64599993408203
109,3.334826551950895,1.0435322013473511,78.16800010253907,94.58799998291016
110,3.303573865156907,1.0032002863311769,78.34800002929687,94.69000000976563
111,3.3007822862038245,1.0601369972038268,78.02399995117187,94.53400008544922
112,3.2984123138281016,1.051457786140442,78.11000013183593,94.75999990478516
113,3.262780079474816,1.1014147911262513,78.36200005859375,94.7040000366211
114,3.28412912442134,1.0476678370285035,78.480000078125,94.75799998535156
115,3.240434399017921,1.0324708117103576,78.42200002929688,94.76400003417969
116,3.2967760287798367,0.9791300971794128,78.43800000244141,94.78400003417968
117,3.261618274908799,1.051374360847473,78.38000015625,94.84999998535156
118,3.2742426487115712,1.0074605080032348,78.4679999975586,94.82000006103516
119,3.276140946608323,1.0688135836410522,78.35999984375,94.60799990478516
120,3.2416304624997654,0.996600732421875,78.74800010498046,94.91000008544921
121,3.2056774359482985,1.0518232615089416,78.67,94.78000000976563
122,3.227355828652015,0.9713950177574158,78.79799997558594,94.97000006103515
123,3.2498799654153676,1.0024585261154175,78.842,94.94400006103515
124,3.246049715922429,0.9771295314407349,78.96599997314453,95.06600011230469
125,3.261179988200848,1.0789605331230163,78.8920000756836,95.08599990478515
126,3.254664714519794,1.004692148284912,79.15199997558594,95.04799998291016
127,3.2356422589375424,0.9843215970993042,78.88800013427735,94.97000003662109
128,3.2246823494250956,1.0390974626731873,79.14800004882812,95.04000006103516
129,3.2226843650524435,1.0247631069946288,79.1079999194336,94.97200000976562
130,3.2466600766548743,0.9585929667854309,79.3819999194336,95.14600006103515
131,3.242083659538856,0.9486254569244384,79.40000020507813,95.22400008544922
132,3.2279451351899366,0.9887630248260498,79.13200005126953,95.02400000732422
133,3.220690608024597,1.0240418742561341,79.25200007324219,95.21199990478516
134,3.218092579108018,0.9779359574508667,79.25000002441406,95.21999998291015
135,3.2125297693105845,0.9315142623901367,79.20600002441407,95.17600013671876
136,3.239817573474004,1.0313529789352418,79.226000078125,95.06800013916016
137,3.18045646410722,0.9608411916351318,79.4799999243164,95.29999993164063
138,3.1873541336793165,0.9543602143478394,79.43999997558593,95.18199998291016
139,3.165460522358234,1.048680834980011,79.42599994873046,95.1680000366211
140,3.1896323424119215,1.0906442671966552,79.2520000805664,95.24000000976562
141,3.1748633751502404,0.9517472584152221,79.57200015136719,95.26599995605469
142,3.1773726298258853,0.9701159319877625,79.45400012695312,95.21600000732421
143,3.159519048837515,1.0661436851692199,79.51800006835937,95.27800010986329
144,3.2158747819753795,0.9745243270874023,79.45800002685547,95.29599998291016
145,3.161927204865676,0.9890392238426209,79.53000012695313,95.23800000732422
146,3.1869288132740903,0.9883130007171631,79.784000078125,95.34600013916015
147,3.1633920211058397,0.9936535754013062,79.90000004882812,95.43400000732422
148,3.190845801280095,0.9889785444259643,79.71599997070312,95.39599990478516
149,3.1653112173080444,1.0283795692443847,79.9499999194336,95.36399992919922
150,3.166087508201599,0.9779253969573974,79.97,95.42999995605469
151,3.1479290081904483,1.0098531292724608,79.94000005371093,95.36200000732421
152,3.1536225997484646,0.9446434412956238,80.28599997070313,95.48000000732422
153,3.110550990471473,0.9584104975318909,79.865999921875,95.52199998046875
154,3.167463806959299,0.997737165107727,79.82200020751954,95.42799998291015
155,3.13789662031027,0.9710854097938537,80.03800002685547,95.44799998046875
156,3.1173960062173696,0.9749535838317871,80.34599997070312,95.50200000732421
157,3.1322424686872044,0.9212129807472229,80.262,95.50399992919922
158,3.1068271673642673,0.9729003804779053,80.38200004882812,95.52800018798828
159,3.1096662007845364,0.9109539186859131,80.596000078125,95.70000005859374
160,3.103513864370493,0.9322344965553284,80.41200009765625,95.57599998291016
161,3.132872122984666,0.882158300819397,80.38400020996093,95.57800003417968
162,3.1003327461389394,1.0276759369468689,80.59400004638672,95.57400000732422
163,3.107926056935237,0.9502820700454712,80.40000017578124,95.60600003417969
164,3.0697899781740627,0.9290867798805237,80.49800009765625,95.72800008544922
165,3.111760973930359,0.9023137127494812,80.51399999511719,95.69000000732422
166,3.0767463353964,0.9757205354499817,80.75000004638672,95.75400008544922
167,3.0687196346429677,0.8881613404273987,80.59600015625,95.73599998046875
168,3.1184717875260572,0.8646588974380494,80.87400010253906,95.79600013916016
169,3.090632071861854,0.9433945862197876,80.9039999975586,95.75199995605469
170,3.115909457206726,0.931336729221344,80.72200017578125,95.74200016357422
171,3.0933127036461463,0.9939709894943237,80.75599999755859,95.78200003417969
172,3.086327782044044,0.8836379608726501,80.86799991455078,95.96399992919922
173,3.077766941143916,0.9232083583641052,80.89999999511718,95.77400003662109
174,3.0573448768028846,0.9307572483062744,80.80800012695312,95.82199990478516
175,3.0762509199289174,0.8996614429855346,81.23000007568359,95.8099998803711
176,3.0396402340668898,0.9776417551994324,81.1700001220703,95.93000000488281
177,3.03109754048861,0.9624793661499024,81.14000007080078,95.93800010986328
178,3.040617612692026,0.9664762647247315,81.03800004882812,95.83199998291016
179,3.0303638715010424,0.894894326210022,81.37600004638672,96.00599992919922
180,3.009007288859441,0.9451599026298523,81.19999999511718,96.01800003417969
181,3.0288018721800585,0.9100659669685364,81.40999994140626,95.98600013671874
182,3.003125685911912,0.887927452545166,81.27000020019531,96.02800003173829
183,3.008813509574303,0.869329244632721,81.37200002685547,96.01800000732422
184,3.0261400571236243,0.9118497411537171,81.52200002197266,96.00800010986327
185,3.004841850354121,0.9412287125587463,81.54200007080078,95.99000010986327
186,3.013606933447031,0.8910019776153565,81.52200012451172,96.04800006103515
187,2.9876741079183726,0.923979015827179,81.59199994384765,96.07599998046875
188,2.9696753850350013,0.9355228743362427,81.60599996826171,96.01600008544922
189,2.954044461250305,0.9194323031997681,81.5740000756836,96.01999998291015
190,3.001693844795227,0.9340647240066529,81.67800007324219,96.07999998291015
191,2.929094571333665,0.8478024571037293,81.85800018066406,96.16800000976562
192,2.990995086156405,0.916919107837677,81.76200004882813,96.18799995605468
193,2.976861696976882,0.9573249016571045,81.62800002197265,96.06000005859374
194,2.9593937397003174,0.8564995549583435,81.91399994140625,96.27199992919923
195,2.941809947674091,0.8975905869483948,81.79000001953125,96.19000003173828
196,2.976395744543809,0.9395446335983276,81.95199991455078,96.15799998291016
197,2.9350195206128635,0.9216514646148681,81.84800004638672,96.24800013671874
198,2.942404765349168,0.833588249835968,82.15200007568359,96.24200003173829
199,2.9270170743648825,0.8581027887535095,82.06399997070312,96.26400005859375
200,2.9117765701734104,0.8230641770172119,81.91799999267577,96.30999992919922
201,2.9535570328052225,0.932054790763855,82.19600002197265,96.28800008544921
202,2.951778081747202,0.8631351574707031,82.04400009277344,96.26199992919922
203,2.8975991835960975,0.9194772821807862,82.2759999975586,96.3100000341797
204,2.9302108104412374,0.8795070449256897,82.24599999511719,96.35400000732422
205,2.88518776343419,0.8075243322563171,82.11200001953125,96.42400003173829
206,2.8956097731223474,0.9984648355102539,82.22799999267578,96.40400000488282
207,2.893759452379667,0.8589856274986267,82.30600004882812,96.39799992919922
208,2.9317749371895423,0.8957522449111939,82.28000004882813,96.35800003173829
209,2.8650061717400184,0.8618767699813843,82.32400001953125,96.39999992919923
210,2.88412383886484,0.8900039534378051,82.3660000439453,96.38799995605468
211,2.8546677277638364,0.8119068654060364,82.51199999755859,96.43000005615234
212,2.89072257738847,0.8770508949661255,82.59799999511719,96.38600005615234
213,2.890521159538856,0.8636206825447083,82.42000004638672,96.48800003173828
214,2.9142217819507303,0.8873085660552978,82.58000004394532,96.38800000732422
215,2.890397998002859,0.8559163021659851,82.7499999951172,96.45200003417969
216,2.858252213551448,0.8466270779418945,82.59599989257812,96.52399992919922
217,2.8843393417505117,0.9059183154869079,82.61600007080078,96.5360000830078
218,2.8276894367658176,0.8348352497673035,82.66399994140625,96.40600010986329
219,2.8219637045493493,0.8524979174041748,82.88600014892579,96.44799992919921
220,2.8488004757807803,0.877018903388977,82.80199997070312,96.53199995361328
221,2.8460423212784987,0.8800382892227173,82.90000007324218,96.49599998291016
222,2.8185073137283325,0.8374453451538086,82.75200012695312,96.53399998046875
223,2.8086545192278347,0.8324967349433899,82.94800004638672,96.54000008544922
224,2.8704542838610134,0.8967138567352295,82.99800004638672,96.60400000732422
225,2.806610419200017,0.8318240044593811,83.09800009521484,96.58800008544922
226,2.8394719453958364,0.87683580160141,83.16000001953125,96.53800003173828
227,2.8203150309049168,0.8570964310073853,83.18600002197266,96.55399992919922
228,2.834619778853196,0.7959771487045288,83.12600009521485,96.64399992919923
229,2.76179917042072,0.8651750437164306,83.08200004882812,96.58200003173827
230,2.79124588232774,0.8245219149398804,83.05600001953125,96.70399998046875
231,2.8034701347351074,0.8726015359115601,83.24400012207032,96.65399992919922
232,2.747335984156682,0.8783034913063049,83.432,96.67800008544921
233,2.7719013415850124,0.8303304001617432,83.30800009765625,96.63400003417969
234,2.7671942710876465,0.8046093546676636,83.38800022949219,96.72800000732421
235,2.7805707179583035,0.8224680534934997,83.34000001953125,96.63600005859375
236,2.7180897180850687,0.8442720851898193,83.31200017578125,96.66800010986329
237,2.7626740840765147,0.8542504468154907,83.55999991699218,96.74800010986328
238,2.7336452649189877,0.7994218709564209,83.48200022460938,96.75800000732421
239,2.7233040791291456,0.8043592370414734,83.36799997070312,96.75200008544923
240,2.6998801781580997,0.8452410479164123,83.50599999267578,96.73600010986328
241,2.73690388752864,0.8795205630493164,83.41199997070312,96.68199995605468
242,2.749963100139911,0.8602245480918884,83.54400004882812,96.77599995605469
243,2.698979405256418,0.8548383228492736,83.53200005126953,96.78000008544922
244,2.72799957715548,0.7891766788101197,83.61400001953125,96.73800008544922
245,2.7299381494522095,0.7910343458747864,83.68000006835938,96.82400010986328
246,2.701122834132268,0.8022129875183105,83.64600004638672,96.83600010986328
247,2.724263282922598,0.8257709519386291,83.74200012451172,96.81400003173827
248,2.700051261828496,0.8478059650039673,83.89000007324219,96.79600013671875
249,2.7121142332370463,0.8197229961776733,83.67799994140626,96.79600003417968
250,2.6754928001990685,0.807203971862793,83.82799999511718,96.82600000732423
251,2.6727003134213962,0.8816155452919007,83.89399997070312,96.84000021484376
252,2.678276859796964,0.7926013770103455,83.91199999267577,96.87600008544922
253,2.693314579816965,0.8830190794181824,83.71199999511718,96.84400008544922
254,2.7000538385831394,0.7949762279510498,83.96000001953125,96.87799992919922
255,2.6520780416635366,0.8772449176788331,83.88600017578125,96.84400008544922
256,2.6728187249257016,0.803320650062561,83.99600002197266,96.90200008544922
257,2.6415497614787173,0.8764836074638367,83.93999999511719,96.85999992919922
258,2.656065143071688,0.9068516694831849,84.12800015136719,96.85599998046875
259,2.6415039667716393,0.7998114633750916,84.04200009765626,96.88399998046874
260,2.6295432585936327,0.8035087095451355,84.12000022949219,96.86799992919921
261,2.6302762306653538,0.7987516967964172,84.03600001953124,96.87400000732421
262,2.6307426599355845,0.8718977772903442,84.01400010009766,96.88399998046874
263,2.646725838000958,0.8528344987106323,84.08799999267578,96.87399998046875
264,2.6446866530638475,0.8193294952964782,84.15799994140625,96.85400010986328
265,2.6704056446368876,0.8128056451225281,84.08600010009765,96.87400003173828
266,2.6277802265607395,0.8045233806419373,84.13200004638672,96.92999998046875
267,2.627194743890029,0.8470412642097473,84.22599999511719,96.84400008544922
268,2.6251884607168345,0.8471642654037476,84.24400010009765,96.92999990478516
269,2.6308794296704807,0.8462172267532349,84.2559999951172,96.91200005859375
270,2.602825742501479,0.8418390936470032,84.20400001953125,96.93599995361328
271,2.614965172914358,0.8266842465209961,84.29800010009765,96.98600003417968
272,2.6357730627059937,0.7932263405036927,84.19200004882812,96.99600000732421
273,2.6102960935005775,0.7815226528167725,84.29200010009765,96.94800003173827
274,2.581689577836257,0.8857900485992432,84.3839999951172,96.94000000732422
275,2.5989666902101956,0.8270471384239196,84.38600009765625,96.99800000732422
276,2.5846868386635413,0.7947231995010376,84.32600004882812,96.97799998046875
277,2.6171456300295315,0.8640070059013367,84.38799999511718,96.96999995605469
278,2.588435182204613,0.8077503689575195,84.31600012695313,96.95199992919922
279,2.625655614412748,0.8132058056640625,84.34200007324219,96.95599992919922
280,2.561418294906616,0.8367598581314087,84.33400010009765,96.99399998046874
281,2.5751408430246205,0.8749884683609008,84.37200007324219,96.97000003417969
282,2.5804577515675473,0.8383332873916626,84.42400009765625,96.94399998291016
283,2.6351821697675266,0.827440677433014,84.4620000732422,96.93799995361329
284,2.609296101790208,0.802014238834381,84.53200015136719,96.99599998046875
285,2.580980291733375,0.8811971377563477,84.35800017578126,96.90999998046875
286,2.612268191117507,0.7776840033721923,84.43000015136718,96.95199992919922
287,2.5652044919820933,0.8067414920043945,84.48200007324219,96.94799998046875
288,2.5902458887833815,0.86946218957901,84.34200015136719,96.94399992919922
289,2.5916363367667565,0.8006659157180787,84.45599999511718,96.94400005859374
290,2.5842091762102566,0.808074400806427,84.36799999511719,96.94000005859375
291,2.5799223551383386,0.7949452325820923,84.49000018066407,96.98000008544922
292,2.598302180950458,0.8045749850654602,84.56400007324218,96.95600000732422
293,2.5514029356149526,0.8555887586593628,84.42600004882813,96.92000008544922
294,2.570864484860347,0.7697095928192139,84.5159999975586,96.95799992919922
295,2.574796667465797,0.7760678444099426,84.45000004882813,96.97799998046875
296,2.5906164095951962,0.8296054859733581,84.56399999511719,96.93799995605468
297,2.5699487374379086,0.8084378022956848,84.59800007324219,96.95400003417969
298,2.589944940346938,0.8603264539909363,84.50600010009765,96.96199995605468
299,2.551648791019733,0.753419985370636,84.51800007324219,96.97800003173828
300,2.600609458409823,0.9259959189987182,84.56600012451172,96.93399992919922
301,2.557686319717994,0.7893291508483886,84.56200004882812,96.99399992919922
302,2.591786503791809,0.8664678469276428,84.50399997070312,96.98600003173829
303,2.5873728715456448,0.8108873423957824,84.55000010009766,96.95000003173828
304,2.571756307895367,0.8550376346588134,84.50399997070312,96.98800000732422
305,2.5680193167466383,0.8534311754226684,84.46400020507812,96.96399992919922
306,2.549737260891841,0.7815708198738098,84.44999999511718,96.97400003173829
307,2.5860844942239614,0.8156077337646485,84.49400007324219,96.96799990478516
308,2.5601435899734497,0.7796311644172669,84.52599999511719,96.99600003173828
309,2.5672685458109927,0.8047046170425415,84.56000004882813,97.00199992919921
================================================
FILE: checkpoint/iformer_large/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: true
aug_repeats: 3
aug_splits: 0
batch_size: 64
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 1.0
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 30.0
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.5
embed_dim: 384
epoch_repeats: 0.0
epochs: 300
eval_metric: top1
experiment: iformer_large
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.001
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-05
mixup: 0.8
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_large
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: 1.0e-08
output: checkpoint
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 5
warmup_lr: 1.0e-06
weight_decay: 0.05
worker_seeding: all
workers: 10
================================================
FILE: checkpoint/iformer_large/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,6.910546504534208,6.850864137573242,0.4499999999666214,1.8380000002288819
1,6.643188366523156,5.693355998840332,4.978000012207032,14.922000020751954
2,6.290052909117478,4.7761042728424075,12.611999971923828,30.79799999633789
3,5.945564820216252,3.82945441116333,23.45400005126953,47.374000043945315
4,5.6793032976297235,3.3477501653289794,30.48999998779297,57.0860001171875
5,5.384508353013259,3.0526050999450685,36.21600011962891,63.426000043945315
6,5.250409749838022,2.7507688496017457,41.93400006347656,68.9760001977539
7,5.041767652218159,2.462726919937134,46.78800009277344,73.27800002929688
8,5.04708271760207,2.4434100090026853,47.634000017089846,74.30800002441406
9,4.975198323910053,2.42360043762207,48.37999992431641,75.1100000048828
10,4.829011660355788,2.2661094271850586,50.68400001708984,76.53200008300782
11,4.770912170410156,2.0769049119567873,54.78399994873047,80.32799993896484
12,4.602100574053251,1.9912925200653075,56.97400007324219,81.9780001147461
13,4.475416091772226,1.903586951904297,58.506000017089846,83.21399985107422
14,4.387893658417922,1.738180983314514,61.67000002929687,85.12800008300782
15,4.2983089960538425,1.6956268508148193,62.94399990478516,85.9460001586914
16,4.295326141210703,1.634114217414856,63.83200000732422,86.87800005615235
17,4.199971575003404,1.5565200023078918,65.03799997558593,87.56400002441406
18,4.11814390696012,1.5510349703979491,65.20400005126953,88.04800012451172
19,4.085179814925561,1.5207680701065063,66.44000000732422,88.56600005126953
20,4.049802477543171,1.4801421740722656,66.96200008056641,88.81599999755859
21,4.023964524269104,1.5004107713890076,68.018,89.53000020507812
22,3.941151738166809,1.4183791262245178,68.2960000805664,89.44400018066406
23,3.92056103853079,1.4148408610725403,68.98600007324218,90.14000014892578
24,3.8803509015303392,1.362926469554901,69.38000001953125,90.02999994384766
25,3.885836592087379,1.3796905069923402,69.72199999267578,90.50999999267579
26,3.826949018698472,1.3466479450416564,70.56800001220704,90.72599996826172
27,3.807827426837041,1.294949890460968,70.44800017578125,91.05599999267578
28,3.7755139882747946,1.324147945137024,71.0500000805664,91.0999999975586
29,3.7466984620461097,1.290724811515808,71.66600008789062,91.4300000415039
30,3.7509118685355554,1.2814777055740356,71.91000001708984,91.6260000415039
31,3.737801056641799,1.2778163366699218,72.21599998535156,91.6360001196289
32,3.710351595511803,1.2349367953300476,72.29000001708984,91.89800001464843
33,3.7168827790480394,1.2245775287628173,72.55399996582031,92.18000006347657
34,3.6832274198532104,1.2634735249519349,72.36800007080078,92.01200004882813
35,3.6934918073507457,1.2468857565307616,72.46200011474609,92.12200006591797
36,3.6704082672412577,1.2496089472961425,72.92200000732421,92.40000008789063
37,3.61841528232281,1.1793549740600586,73.54600006835938,92.4979999633789
38,3.6186260260068455,1.176423654499054,73.8760000415039,92.71800011962891
39,3.6086827516555786,1.1873519218063355,73.72600017089843,92.56600004150391
40,3.6681673526763916,1.1816426961135864,73.89600013916015,92.75600006591797
41,3.5902654574467587,1.1312151807022095,73.84999998291016,92.69200009033203
42,3.584832870043241,1.1513134127235412,74.2380001953125,92.85600009033203
43,3.5984942454558153,1.1854713468933105,73.94800005859375,92.85000013916016
44,3.6091530047930203,1.1779234772491456,74.45200006103515,93.05000013916016
45,3.485910278100234,1.124913458557129,74.51200013916015,92.87999999023438
46,3.53954606789809,1.1186642870903014,74.91000016113281,93.23799998291015
47,3.5389673618169932,1.1638443058395387,74.89200000976562,93.29199998779296
48,3.581275013776926,1.1198174129295349,74.83599997558593,93.1000000366211
49,3.550371078344492,1.124014576892853,75.47600002929687,93.30400001220703
50,3.518071541419396,1.116720790939331,75.13999995117187,93.46999998535156
51,3.544528530194209,1.1221498290252685,75.35199995605468,93.39000006591797
52,3.5305112233528724,1.153511095676422,75.22399997558594,93.4659998803711
53,3.512775265253507,1.1154144276809692,75.54600013427735,93.59000001220703
54,3.5236593118080726,1.1202192064666747,75.3940000390625,93.45000014404297
55,3.433111392534696,1.096172775325775,75.87800006347656,93.67000001708985
56,3.4627847304710975,1.1273769705581664,75.95400010986329,93.6540000366211
57,3.4790163223560038,1.0841456203651427,76.21400003173828,93.7280001928711
58,3.469840278992286,1.1170489598274231,75.92600005615235,93.62200009033204
59,3.4525169225839467,1.0933158408927917,76.21000011230468,93.7460000415039
60,3.4901640873688917,1.126006812171936,76.19000008300782,93.78400006347657
61,3.4491439507557797,1.0970030860710145,76.21999998535156,93.81000001220703
62,3.432765245437622,1.102595337715149,76.3400001611328,93.96200000976563
63,3.4307719469070435,1.0557644449996948,76.3160000341797,93.7779999609375
64,3.4119530549416175,1.0703777256584168,76.16200011230468,93.74999993408203
65,3.4012008630312405,1.0802638423919677,76.45600000732422,93.7179998803711
66,3.4363037989689755,1.0589085584640503,76.77399997558594,94.11200008544922
67,3.442072950876676,1.079272590522766,76.80000013183594,94.03599990478516
68,3.4058668430034933,1.0758214369010926,76.81999989990234,94.10399990966796
69,3.420797137113718,1.0494617863845825,76.91000010986328,94.1500000366211
70,3.425681655223553,1.0563462660980225,76.71800008789063,94.0239998803711
71,3.4518049221772413,1.0740118409538268,76.88800006103516,94.09999998535156
72,3.3764912531926083,1.0612143629455566,77.09400005859375,94.22000000976563
73,3.396061952297504,1.0656421076011657,76.84600000488281,94.17600008789063
74,3.377353860781743,1.038946965942383,77.39399997802734,94.1700001928711
75,3.3845853255345273,1.0483537149047852,77.19200008544922,94.14399993408203
76,3.4312566060286303,1.0124181773376464,77.10600003417969,94.2560001147461
77,3.3898003651545596,1.0355482028388978,77.07400000488282,94.21800013916015
78,3.3775642835176907,1.0285578982925414,76.92400008789062,94.15399995849609
79,3.407664785018334,1.04491987159729,77.33000002685547,94.2539999584961
80,3.323443284401527,1.0443046270751952,77.40400016113281,94.26999998535156
81,3.3806505845143247,1.0459676809692382,77.52599987792969,94.2600000366211
82,3.382426720399123,1.042313983669281,77.59600008300781,94.29800001220703
83,3.3965557446846595,1.0520244059181214,77.6939999243164,94.46400005615234
84,3.3593188799344578,1.0567831263923646,77.55600005615234,94.34799995849609
85,3.3830844530692468,1.0939044120025634,77.37399982910156,94.3300001147461
86,3.3928306561249952,1.053267989654541,77.36200018554688,94.33200006347656
87,3.3483442343198337,1.03019209690094,77.79400000976563,94.4760000390625
88,3.400100533778851,1.0476388161849977,77.95800003173828,94.43800006103515
89,3.363451132407555,1.03560297996521,77.67800000488282,94.4600001171875
90,3.356158916766827,1.0289478774261474,77.82800010498048,94.31199995849609
91,3.3031399066631613,1.0243545026016236,78.07000008789062,94.37400001220703
92,3.340137399159945,1.0346087313461303,78.11400003173829,94.68599998291016
93,3.3096237182617188,1.0451400087165832,77.87400008300781,94.53799998291015
94,3.3592547178268433,1.0183623178482055,78.16999997314453,94.54400021728516
95,3.3121973734635572,1.034790303516388,78.08200010742188,94.58400011474609
96,3.285599561838003,1.0314845135879516,78.20800007324219,94.63800003417968
97,3.3367106272624087,1.027085668697357,78.33600010253906,94.68800010986328
98,3.3270413325383115,1.0407975508117675,78.17400010253907,94.70200016357421
99,3.2877965707045336,1.0243351782989503,78.14800010498047,94.69000021728516
100,3.2988253831863403,0.9814645049095154,78.23000010498046,94.48999995605469
101,3.2901770885174093,1.0053197279930115,78.35000014160157,94.68800014404297
102,3.2781461018782396,1.0209910627746581,78.3820000805664,94.74400003662109
103,3.303452546779926,1.0421659322166443,78.30199997802734,94.71599993408203
104,3.291901891048138,1.010951399459839,78.26399998046875,94.74600000976562
105,3.3149656699253964,1.0163331426811217,78.25600008544922,94.77600006103516
106,3.308741569519043,1.0146123032188417,78.522000078125,94.77400008544922
107,3.293911713820237,1.0555380408668518,78.21600012939453,94.7800000341797
108,3.2435123920440674,0.9968828507232667,78.66600003173828,94.75200008544923
109,3.3232491658284116,1.0174081901168823,78.465999921875,94.74800008544922
110,3.2690613820002627,1.0153419143104554,78.45600010742187,94.82400006103515
111,3.2709354712412906,0.9924267137527466,78.67800013183594,94.75800001220703
112,3.2855274860675516,1.0057064008140564,78.87199997314453,94.99600000732421
113,3.230813585794889,1.0080310935401917,78.80600005615234,94.86400006103516
114,3.2722094242389383,1.0076740469360352,78.86400005371094,95.06999998535156
115,3.22453341117272,0.9959960584068298,78.82800008056641,94.95199995605469
116,3.2814402763660135,0.9519607814407348,78.487999921875,94.87799998535156
117,3.2548026855175314,1.0326564323997498,78.628000078125,94.9060000366211
118,3.2524013702686014,1.0131952097511292,78.76800005126952,95.01200003417969
119,3.2414560868189883,1.0325893316078185,78.83600013183593,95.07600000976562
120,3.2432524240933933,0.9700666660690308,79.02600014892577,94.98200008544922
121,3.2066384187111487,1.0066341878509522,79.138000078125,94.99800003417968
122,3.212906388136057,0.9771206850433349,79.23599997558594,95.18600000732422
123,3.2288889243052554,0.9860764531326294,79.150000234375,95.0760000390625
124,3.227351968105023,0.987032350063324,79.35400015625,95.12600000732422
125,3.229693678709177,0.979343092956543,79.24400004882813,95.12000013916015
126,3.239646599842952,0.9877924637985229,79.15600005126953,95.2240001123047
127,3.2091639592097354,0.9348715629386902,79.27000005371094,95.20400013916016
128,3.2042169387523947,1.0053931137657166,79.28400007324218,95.2520000366211
129,3.2251157577221212,1.0165403331756593,79.19800013183594,95.1359999584961
130,3.2341381861613345,0.9704256143569946,79.396,95.2599999584961
131,3.2344076266655555,0.9606916565704345,79.44599995361328,95.26999990478515
132,3.2237201837392955,0.9550872252464294,79.938,95.31600000976563
133,3.2206397239978495,0.9822869607162475,79.31199984375,95.14399998291016
134,3.190979370704064,0.9883818350982666,79.44399992675781,95.17000013916015
135,3.2153615951538086,0.9423513364982605,79.59600007568359,95.31000003662109
136,3.2060717894480777,0.9656000201225281,79.69599997314454,95.30999998291016
137,3.169634094605079,0.9717180125236511,79.84799997558594,95.30599998291015
138,3.1719662959759054,0.9458955500030518,79.91200010498046,95.36600000732422
139,3.1517449158888597,0.9693083749580383,79.68799994628907,95.28599998291016
140,3.175395580438467,0.9482538063621521,79.91800015869141,95.36599995849609
141,3.169075736632714,0.9496471033668518,79.8199999194336,95.31599998291016
142,3.162912818101736,0.9770244847106934,79.74799994873047,95.29200003662109
143,3.149576177963844,0.9529659438323974,80.12799994140624,95.37199998291015
144,3.1874825679338894,0.9519318865776062,79.84200002441406,95.37400008789062
145,3.1350886271550107,0.9400943297576905,79.93600007324218,95.38200006103516
146,3.1708294336612406,0.9612622356414795,79.9659999243164,95.4799999609375
147,3.155010828605065,0.9531323616218567,80.27799997558594,95.48200008544922
148,3.172439987842853,0.9382277800369263,80.16600015625,95.53000011230469
149,3.1274655048663798,0.9533389562416077,80.13000002441406,95.55600008789062
150,3.1443814772825975,0.9371918731307983,80.33399997070312,95.56000006103515
151,3.1267175674438477,0.9546587055587769,80.41200012695313,95.64400006103516
152,3.153881384776189,0.9710652282524109,80.404000078125,95.69400000732422
153,3.0826464799734263,0.9507630409240723,80.43200005126953,95.59800003417969
154,3.15419864654541,0.9281492255783081,80.31200010009766,95.61800016357422
155,3.1137414620472836,0.9433984814453125,80.35400004882813,95.52200008544922
156,3.0899893045425415,0.9258719791603088,80.62000002197266,95.58799998291016
157,3.10144082399515,0.9700759565734863,80.70800010009765,95.55399992919922
158,3.0753163374387302,0.9230545413589477,80.6819999243164,95.64000016357421
159,3.0889107905901394,0.9027757112121582,81.10200007568359,95.81000018798828
160,3.0785955374057474,0.9065625233840943,80.93000007324218,95.75000000732422
161,3.107429366845351,0.9389520249366761,80.83800002197266,95.78800006103516
162,3.0843942990669837,0.9132129728317261,81.03600018066406,95.69800008544922
163,3.0945202020498424,0.9214211973762512,80.93000001953125,95.87799995361328
164,3.054575278208806,0.9130185977935791,80.78400002197266,95.78399992919923
165,3.0874763360390296,0.9020014848709107,80.94199999511719,95.86800013671875
166,3.0821682856633115,0.9138980267524719,81.06599997070313,95.79600006103516
167,3.040875664124122,0.9022171401786804,80.85800012695313,95.83800005859375
168,3.091547810114347,0.9082347438430786,80.89400007324218,95.83400016357422
169,3.100938613598163,0.9023391273498536,81.07800004882813,95.79400008544921
170,3.096641127879803,0.9022777452278137,81.05000004882812,95.85199992919922
171,3.0923791023401113,0.9108617130088806,81.129999921875,95.89400010986328
172,3.052710037965041,0.9479087080192566,81.01200005371093,95.84800010986328
173,3.0646523787425113,0.9275116222953796,81.29999997070313,95.96799998046875
174,3.0505178708296556,0.9056856367492676,81.38399997070313,95.94599992919922
175,3.06196405337407,0.9053907576179504,81.12999994628906,95.80200000732422
176,3.0117578506469727,0.8890290698051453,81.45199989257813,95.95200005859375
177,2.998016329912039,0.9136776152038574,81.152,95.88400008544922
178,3.009301029718839,0.9255991547012329,81.41199994140625,95.99000000732421
179,3.0024658258144674,0.8959691961097718,81.5320000756836,95.97800008544922
180,2.9945607552161584,0.9226432334518433,81.57999989257813,95.97600008544921
181,3.014344673890334,0.8720306507110596,81.6679999975586,95.93400000732422
182,2.9808755196057835,0.8806172090339661,81.3259999194336,96.07399995605469
183,2.9757775435080895,0.8842584958076477,81.55599999511719,96.05399990478516
184,3.0188873547774095,0.8955323510360718,81.74600004394532,95.99000000732421
185,2.9996942465121927,0.8864319196891784,81.68200012451172,96.08600006103515
186,2.9925501071489773,0.8620990559005738,81.90599994140625,96.17200008544921
187,2.9568693821246805,0.8663196732139588,81.82000007080079,96.12199992919922
188,2.9661437639823327,0.8848002459716797,81.72799994140625,96.10600000732421
189,2.9302183389663696,0.898970189743042,81.787999921875,96.08400003417968
190,2.982291790155264,0.898637077960968,81.8679999194336,96.07000003417969
191,2.9291364137942972,0.8914895185279846,81.88999997070313,96.12599998046875
192,2.9678718952032237,0.9045106686401367,82.06000002441407,96.13800005859375
193,2.9405993773387027,0.8911021675300598,82.00600007324219,96.13399998291015
194,2.942032126279978,0.8845843422508239,82.24200002441407,96.20000000976563
195,2.9096350211363573,0.8615652515411377,82.03000015136719,96.26600000732422
196,2.978213383601262,0.8571838988304138,82.2619999951172,96.31399995605469
197,2.915033533022954,0.8975017419433594,82.32400007324219,96.32999990478515
198,2.90700048666734,0.8796381487083436,82.37400012939453,96.22200000732421
199,2.891255121964675,0.8519094377326966,82.1440001513672,96.24399990478516
200,2.8835707902908325,0.8565283406829834,82.1980001538086,96.35399992919922
201,2.922385866825397,0.8883370819091797,82.22599999511719,96.34400010986329
202,2.921412935623756,0.8623068803215027,82.43399994140626,96.26800003417969
203,2.855884689551133,0.8547059134101868,82.58599997070313,96.36800000732421
204,2.8865085656826315,0.8690887733078003,82.58200009765625,96.42799998046875
205,2.8650464919897227,0.8593528522872925,82.65200002197265,96.35600000732421
206,2.8669342352793765,0.8630787907600402,82.83999988769531,96.36600003173828
207,2.8705076346030602,0.8752160062599182,82.48999994384765,96.35800000732422
208,2.901468093578632,0.8615425992584228,82.84400007324219,96.39800000732421
209,2.8473127896969137,0.8606333646774292,82.68400020019531,96.49200000732422
210,2.8600355203335104,0.8499924317932129,82.50199997070312,96.38400005859376
211,2.835801885678218,0.8610882561302186,82.69799999511719,96.39600003173828
212,2.851739296546349,0.8339872376823425,82.96800017578126,96.49599990478515
213,2.8614711669775157,0.8409391230201722,82.8040001220703,96.54600005615234
214,2.8752744381244364,0.8455844373703003,82.92400012207031,96.59000003173828
215,2.8528220103337216,0.830362998752594,83.00600001953126,96.54000000732422
216,2.8457366869999814,0.8478204371833802,82.98000012695313,96.51600010986328
217,2.8469672203063965,0.8468237362861634,83.06399986328125,96.54400010986328
218,2.798969342158391,0.8438536651229859,83.34799994140624,96.55800005859375
219,2.7844803608380833,0.8434881903266906,83.21400012695312,96.59400006103516
220,2.8229461449843187,0.851990555973053,83.23599989257812,96.62600008544922
221,2.8255511980790358,0.8601199535369873,83.14200005126953,96.58200005859375
222,2.769902761165912,0.8383212358665466,83.32399994628906,96.50600018798828
223,2.788533880160405,0.8396528512954712,83.27799996826172,96.50800000732421
224,2.830211648574242,0.8376853758049011,83.44600001953125,96.66000003173828
225,2.7834478525015025,0.8274393226051331,83.33000004394532,96.56200008544921
226,2.8031968611937303,0.83849549451828,83.34400001708984,96.56799992919922
227,2.7822887163895826,0.8502991510391236,83.36599996826172,96.65399990478515
228,2.7865052681702833,0.8371617828369141,83.30399996582031,96.65000018798828
229,2.736697848026569,0.8339061546897888,83.46600001708984,96.67600003173828
230,2.756590568102323,0.8257232703781128,83.60399996826172,96.74800003173829
231,2.754254698753357,0.8320276457595825,83.5720000732422,96.66800003173829
232,2.717001960827754,0.8391983484077453,83.53200017578125,96.62000000732422
233,2.7290770640740027,0.8272213029289246,83.58199996826171,96.76800008544922
234,2.744955145395719,0.8214240999031067,83.49600012695312,96.77000000732421
235,2.7437443091319156,0.8427003208732605,83.48800015136719,96.67799995605469
236,2.6601342146213236,0.8298304134178162,83.58400004882813,96.70799995605469
237,2.7236390939125648,0.8376305979537964,83.87200014892578,96.78199992919922
238,2.6828131584020762,0.8278305959701538,83.7240000415039,96.82199998046875
239,2.686338681441087,0.8361137383460998,83.79000009765625,96.7519999584961
240,2.6599515951596775,0.8282428256988525,83.8899999951172,96.84599998291016
241,2.708692715718196,0.840486863193512,83.86400007324218,96.81600008544922
242,2.706350097289452,0.8143806471061706,83.84199999511719,96.72200008544922
243,2.675070423346299,0.820715518321991,83.93799999511718,96.78400005859375
244,2.67621019253364,0.8190324502944947,83.76800007324219,96.73399990478515
245,2.6871630870378933,0.8306452424812317,83.95400001953125,96.75600003173828
246,2.668807891698984,0.8138127730178834,83.83600001953126,96.82799990478516
247,2.670647089297955,0.8158876292228698,83.91600014892578,96.84599992919922
248,2.6516282008244443,0.8102354573631286,83.95800007080078,96.90400000732421
249,2.6651534667381873,0.8250773019981384,84.03800012207032,96.84999998046875
250,2.634173879256615,0.8186645620155334,84.04400009765625,96.84399998291016
251,2.6318732500076294,0.8051316935539246,84.12400007324219,96.82799995605468
252,2.6179944460208597,0.7980528848648071,84.06599999267578,96.82800003173828
253,2.647362910784208,0.834535809059143,84.21600001708984,96.90199992919922
254,2.6695059812985935,0.7965805632972718,84.12399994140625,96.86800000732421
255,2.6050742131013136,0.8134198690032959,84.21600009765625,96.84599992919922
256,2.628141458217914,0.8186293986320495,84.19600007080078,96.83400003173828
257,2.5968537147228536,0.8014547142791748,84.20799999267578,96.88799992919922
258,2.595403607075031,0.8086359707069397,84.36800004638671,96.89000003173828
259,2.5980473114893985,0.8042134921073913,84.22400004638672,96.88799998046875
260,2.579132290986868,0.8025791395187378,84.32800009765624,96.89399998046875
261,2.578222779127268,0.8010256945419312,84.27399991699218,96.87000000732422
262,2.583704893405621,0.8182663324546814,84.27199994140625,96.88799992919922
263,2.5933315478838406,0.8125042825698853,84.42200009765625,96.96999992919922
264,2.594082392179049,0.8039098625183105,84.48800001953126,96.92399998046875
265,2.590479172193087,0.7950776927566529,84.36800004394532,96.94000003173828
266,2.5677045033528256,0.7946233463287353,84.27600012451173,96.92599998046875
267,2.5831249952316284,0.7977374027442932,84.35800007324218,97.02199998046875
268,2.574748296004075,0.8000576438903808,84.36599989257813,96.96199998046875
269,2.572749523016123,0.8113612873840332,84.43400009765625,96.96200008544922
270,2.5436680500323954,0.8111577160453797,84.49600012207031,96.93799998046875
271,2.5695200791725745,0.8020232275581359,84.50399999511718,96.95399990234375
272,2.5755814955784726,0.8084232340240478,84.37200012207032,96.98400005859375
273,2.5361680342600894,0.8032256897354126,84.48599991455077,96.98399992919921
274,2.5335547832342296,0.8067118916130066,84.49200004394531,96.95800003173828
275,2.5463146613194394,0.8099793620681762,84.49399999267578,96.92399992919921
276,2.5408946642508874,0.8107297235488892,84.56000012451172,96.97399992919922
277,2.5601063875051646,0.79685419921875,84.54400001953125,96.94800000732423
278,2.5401773361059337,0.8168009343910217,84.58600015136719,96.91799992919921
279,2.5693678764196544,0.8146619980621338,84.57600012207031,97.00399992919922
280,2.4974713142101583,0.8080339497184753,84.60600020019531,96.95999992919921
281,2.5239082941642175,0.8044849347305297,84.65400004394532,96.98600008544922
282,2.5418753899060764,0.7914894469451904,84.64600007324219,96.99399998046874
283,2.5818122075154233,0.8061416247177124,84.64800012207031,96.94399998046875
284,2.5547984655086813,0.7970900491142273,84.6160000439453,96.95999992919921
285,2.5307810673346887,0.8120660132408142,84.62400004394532,96.95200008544921
286,2.568904161453247,0.793556639137268,84.68400009765625,96.94599998046876
287,2.5109933614730835,0.7803754375076294,84.68400004638671,96.99200005859375
288,2.5294809433130117,0.8061936453437805,84.63400014892578,97.01599992919922
289,2.5508905740884633,0.8128651645469666,84.67599996582031,96.97600000732422
290,2.5313611947573147,0.7819976863479614,84.75200001953125,97.03600008544922
291,2.5070205376698422,0.8001993773460389,84.75200009765625,97.01799992919922
292,2.5351219268945546,0.8031717999839783,84.68800009765626,96.98799992919922
293,2.480671396622291,0.8163573293685913,84.67800007080078,96.96999992919922
294,2.4951118505918064,0.7894428477478027,84.67400001953125,96.97799992919921
295,2.517219598476703,0.799906842327118,84.6820000732422,97.00399998046875
296,2.517501244178185,0.7853897292900085,84.69600001953125,96.98999992919921
297,2.4902893488223734,0.8003411549377442,84.72799996582032,97.00999998046875
298,2.5424841458980856,0.7879449011230468,84.69800012207031,97.01599998046875
299,2.490955499502329,0.7997499727249145,84.62600017578124,97.00399992919922
300,2.5226229337545543,0.798079895362854,84.67600009765626,96.98999992919921
301,2.4854222994584303,0.8056669466209412,84.73400004394531,96.95399992919921
302,2.51904087800246,0.8224670502853394,84.64000007080078,96.97400000732422
303,2.5248863238554735,0.7937660120773316,84.72000004394532,96.96400010986328
304,2.4830260826991153,0.8070458587837219,84.63000001708984,96.97000010986328
305,2.520942073601943,0.809012079486847,84.58200012207031,96.94400010986328
306,2.4892017474541297,0.795564078617096,84.64000017333984,96.96200010986328
307,2.529263872366685,0.8014956588554383,84.71000012207031,96.97800000732421
308,2.510762306360098,0.7898429228019714,84.71800009765624,96.95399992919921
309,2.516177461697505,0.7976002066612243,84.6880001220703,96.99199998046875
================================================
FILE: checkpoint/iformer_small/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: false
aug_repeats: 0
aug_splits: 0
batch_size: 64
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: null
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 1.0
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 100
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.2
embed_dim: 384
epoch_repeats: 0.0
epochs: 300
eval_metric: top1
experiment: iformer_small
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.001
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-06
mixup: 0.8
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_small
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: null
output: checkpoint
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 5
warmup_lr: 1.0e-06
weight_decay: 0.05
worker_seeding: all
workers: 10
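The args.yaml above is a flat serialization of the training script's parsed arguments, mostly simple `key: value` scalars. A minimal stdlib-only sketch of reading such keys back into a dict (the inline string copies a few entries from the file above; for the full file, which also contains lists and nulls, PyYAML's `yaml.safe_load` would be the robust choice):

```python
# A few keys copied from checkpoint/iformer_small/args.yaml above.
ARGS_SNIPPET = """\
model: iformer_small
lr: 0.001
epochs: 300
batch_size: 64
drop_path: 0.2
sched: cosine
"""

def parse_scalar(value: str):
    """Best-effort conversion of a YAML scalar to int, float, or str."""
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value

def load_args(text: str) -> dict:
    """Parse simple 'key: value' lines of an args.yaml dump into a dict."""
    args = {}
    for line in text.splitlines():
        if ":" not in line:
            continue
        key, _, value = line.partition(":")
        args[key.strip()] = parse_scalar(value.strip())
    return args

args = load_args(ARGS_SNIPPET)
print(args)
```

This is only a sketch for the scalar entries; nested keys such as `ratio:` and `scale:` (YAML lists) need a real YAML parser.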
================================================
FILE: checkpoint/iformer_small/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,6.9094244333413934,6.864997950286865,0.3860000048828125,1.5220000048828124
1,6.569643974304199,5.442198684692383,6.860000015258789,19.424000006103515
2,6.168546676635742,4.498686260910034,16.655999987792967,37.183999958496095
3,5.779889510228084,3.6264168195343016,27.648000013427733,52.987999997558596
4,5.510022530188928,3.047014978942871,35.92400004272461,62.721999880371094
5,5.173868894577026,2.683674906616211,43.09600005615234,70.35399994384765
6,4.986480364432702,2.4179258517074587,49.2040000390625,75.45399994873047
7,4.783603338094858,2.1778022815322875,53.34999995605469,79.12000002197266
8,4.67784386414748,2.046561869277954,56.38800017578125,81.42400006103516
9,4.566337640468891,1.971249479675293,58.90599991210937,83.36400008544922
10,4.354064602118272,1.9051295560455321,60.89799993652344,84.71400011230469
11,4.3395165846898,1.8117209907531737,61.95200006591797,85.39400020751953
12,4.263674002427321,1.763902481956482,63.86200010986328,86.63800010742187
13,4.20427758877094,1.7423424057769776,64.12000009033203,87.0420000830078
14,4.176988821763259,1.55658791595459,65.61000010742187,87.72
15,4.075015709950374,1.509385205974579,66.51999999023438,88.22600004394532
16,4.076420655617347,1.554206342601776,66.93600020263672,88.6279999951172
17,4.026540004290068,1.5549724096870423,67.67199991455078,89.07600001953125
18,3.9580479585207424,1.5148561889839172,67.98800004150391,89.18000014892579
19,3.957311511039734,1.4817134371757508,68.57399999755859,89.70000004638672
20,3.940097460379967,1.4594665077018738,69.18200001708985,89.91600001708984
21,3.9163199479763326,1.3787175881767273,69.98600007080078,90.29000004394531
22,3.838025588255662,1.378886430244446,69.57000009277344,90.28199999511719
23,3.810691558397733,1.3630643543052674,70.48800014648438,90.51600001953125
24,3.808310343669011,1.3218276855659485,70.53600006591797,90.73799999267578
25,3.835831018594595,1.361201537361145,70.79800001220703,90.82400004394532
26,3.7596489832951474,1.3758483277130127,71.0260000366211,91.13600012207031
27,3.7489256125230055,1.384459337501526,71.02999996337891,91.05399994140625
28,3.7584116642291727,1.3283219172668457,71.97800001220703,91.45600009521485
29,3.7523052509014425,1.3312880218887329,71.82800014160156,91.43600006835938
30,3.7243771736438456,1.296002097415924,71.90199990722657,91.47000001464843
31,3.7210132158719578,1.3768975380706787,72.24400001464844,91.61199999267578
32,3.6862072394444394,1.2716120315933228,72.53200000976562,91.7980001171875
33,3.714735608834487,1.2190303552818298,72.40000003417968,91.77000009033203
34,3.69110553081219,1.2783337624168396,72.73400003417969,91.73400006591797
35,3.6636599118892965,1.313292693901062,72.52200013671875,91.90200009033204
36,3.649338933137747,1.215225392036438,72.95600008544922,92.00399999023438
37,3.640485250032865,1.2873099118232727,73.19200000976562,92.10800004394531
38,3.649407689387982,1.2403657006073,73.21400006347656,92.20200014404297
39,3.606833659685575,1.235068807926178,73.37400013671875,92.18999996337891
40,3.7104046161358175,1.2011462058448792,73.44600001220704,92.3440000415039
41,3.6258209668673,1.3457986955070496,73.85400016601562,92.31399993896484
42,3.599777863575862,1.2382625227165223,73.50000000732422,92.20800001464843
43,3.6305395456460805,1.2349728396415711,73.44799985351563,92.32800001464844
44,3.6406014424103956,1.2155342070961,73.8500000415039,92.33399993896484
45,3.5644080822284403,1.2261458265686036,73.79799990722657,92.51599986083984
46,3.581205817369314,1.2117936054229737,74.01800008789063,92.4740000415039
47,3.5979900726905236,1.2034554265403747,74.19799993164062,92.5880000390625
48,3.627637432171748,1.264913904914856,74.18600008544922,92.46200009277344
49,3.580585874043978,1.213218408164978,74.36400011474609,92.62800009521484
50,3.5658984092565684,1.2477872757339477,74.36799991210937,92.71800006591796
51,3.5925065829203677,1.1601819841575622,74.7800000341797,92.90200011962891
52,3.535925791813777,1.2709670811653138,74.69200011230468,92.8500001977539
53,3.57413322191972,1.2347692168426514,74.53600013183593,92.70200006591797
54,3.569970341829153,1.2118336687660218,74.42800001220704,92.91000017333984
55,3.4951596718568068,1.232526140346527,74.52799998046875,92.87799993896485
56,3.528607414318965,1.2083224132537842,74.58200013671875,92.87600006835937
57,3.5129952339025645,1.2047773654174805,74.81600010986328,92.98000016845702
58,3.5265593253649197,1.1868788903999328,75.10600000488282,93.03999998779297
59,3.527732014656067,1.1377841114616394,75.05200003417968,93.12600011962891
60,3.5698564511079054,1.2140665634918213,75.15199995849609,93.22600009033204
61,3.5069606395868154,1.212316800441742,75.02599998046875,93.09600006347657
62,3.5076429568804226,1.207794077129364,75.30800006103516,93.16800016845703
63,3.5095110673170824,1.1751184555625915,75.05800008300781,92.98200006591797
64,3.495274332853464,1.1663758554077148,75.43800009033203,93.19200001220703
65,3.496675436313336,1.156433831691742,75.37000001220703,93.27800001220703
66,3.513992263720586,1.1885098503875733,75.70400010986329,93.41200008789062
67,3.504016555272616,1.1901980477905274,75.6640000341797,93.48599990966797
68,3.5192229656072764,1.2184467774772645,75.40999995361328,93.28399993408203
69,3.4537463738368106,1.144714741153717,75.72999997802734,93.42600001220703
70,3.50569604910337,1.1589157112121582,75.66200005371094,93.31400006591797
71,3.5316452154746423,1.0998509642791747,76.00400008544922,93.43200011962891
72,3.4660944846960215,1.125511461544037,76.02200003173829,93.37199996337891
73,3.455447737987225,1.2478693104553222,75.81600005615235,93.34800016845703
74,3.474152766741239,1.2159589889144897,75.98600014404298,93.4120000390625
75,3.4426536560058594,1.135044396018982,75.85400003173828,93.53000008789063
76,3.5005324253669152,1.1231057829284667,75.93800005371094,93.47800016601562
77,3.4688816253955546,1.1647394517326355,76.06999992919921,93.5140000390625
78,3.4501122419650736,1.0868774487495423,76.11599992675781,93.48000001220703
79,3.4457626984669614,1.1312008828926086,76.07600013671875,93.55400003662109
80,3.3960279593100915,1.0933656254005433,76.20400010986329,93.6460000390625
81,3.4835668068665724,1.1270457012176514,76.2800000830078,93.7440001147461
82,3.4436028645588803,1.126568383769989,76.30000010742188,93.69600006591797
83,3.4565229690991917,1.1777190805244446,76.37800005371093,93.68800011474609
84,3.414357983149015,1.188438509941101,76.37800010986328,93.59000019287109
85,3.459013196138235,1.2144730036735534,76.43800008544922,93.6420000366211
86,3.45191019314986,1.1025224459266663,76.35400005859375,93.5240000390625
87,3.412708933536823,1.1987829148101807,76.59799992675781,93.79000006347657
88,3.466808089843163,1.1002231981658936,76.82800010986328,93.8160000390625
89,3.4394527673721313,1.0912003845787048,76.66200010986329,93.84000009033203
90,3.434341018016522,1.1553971852874756,76.66199998291016,93.88400003662109
91,3.394513542835529,1.0940003709220887,76.6640000830078,93.77599998291015
92,3.434716270520137,1.107185283679962,76.77800005859375,94.02400000976563
93,3.3863757940439076,1.1242798097229003,76.96600002929688,93.85199998291016
94,3.430507412323585,1.1021452717971802,76.88600005371094,93.89400003173829
95,3.381525305601267,1.1088750526237487,77.00400008544922,93.92400016845703
96,3.387681630941538,1.1086897051239013,76.98400005371094,93.8740001147461
97,3.407272118788499,1.0778170557975768,77.01400020996094,94.00999993408203
98,3.4062380607311544,1.0873821159362793,76.81799995361328,93.92000001220703
99,3.3701010025464573,1.1149417749214172,76.86600010986328,94.01000019287109
100,3.3691365076945377,1.1462936076545716,76.86200003173828,93.82600013916016
101,3.375997204046983,1.0753991451454163,76.94400018554687,94.07800009033203
102,3.357974116618817,1.135077677707672,77.24800008300781,94.0820001171875
103,3.3830726513495812,1.1030329714012146,77.17000008056641,94.06600008789063
104,3.367032500413748,1.082635837879181,77.20999989746093,94.11800001220703
105,3.3800272391392636,1.0921122213745118,77.23400008300781,93.99600001220703
106,3.3662016024956336,1.1072591597175598,77.01600000732422,94.1160001147461
107,3.3502725729575524,1.1785977474021911,77.32,94.09800008789063
108,3.340722441673279,1.0760896536827087,77.47400005859375,94.14800014404297
109,3.3812398176926832,1.083845131263733,77.57799998046875,94.2759999584961
110,3.3553919150279117,1.1084872850608827,77.42600009033202,94.15200001464844
111,3.325396336041964,1.1166858360671996,77.54600012939453,94.14599993164063
112,3.349822860497695,1.085412045841217,77.49600010253906,94.20199998291015
113,3.3029517852343044,1.1392902358436585,77.62800005859376,94.40200006103515
114,3.3522927577678976,1.142200709590912,77.652000078125,94.34400016601562
115,3.316668547116793,1.074139283103943,77.68600005371094,94.45799995849609
116,3.326545972090501,1.07850544713974,77.77000008300782,94.21599998291016
117,3.315403158848102,1.145295484161377,77.74600005615234,94.41199998291016
118,3.315896850365859,1.0736533163452149,77.73600005126953,94.33600000732422
119,3.332959147600027,1.0669116265106202,77.61800013916016,94.34399993408203
120,3.342320543069106,1.0954551971435547,78.09600010986328,94.49600006591797
121,3.269335691745465,1.109053556213379,78.05999987304688,94.4940001147461
122,3.299822678932777,1.059656356754303,78.30600013427734,94.56000011474609
123,3.3047666274584255,1.120321944065094,77.90000005371094,94.41400000976563
124,3.318658021780161,1.0599626556968689,78.116000078125,94.53599998291016
125,3.3260184893241296,1.104176430568695,78.12400000732421,94.45600003662109
126,3.327208610681387,1.080815853767395,77.98200008300782,94.48600008789063
127,3.2752062265689554,1.0735882032585145,77.95600008300781,94.48600006591796
128,3.2862806778687697,1.1019770643424989,78.19400015625,94.37000008789063
129,3.300983373935406,1.1177044029426575,77.99000013183594,94.31200008789062
130,3.2768982557150035,1.0396747249031066,78.43000000488281,94.7400000366211
131,3.2978845559633694,1.0482300101852418,78.12200002441406,94.59199998291015
132,3.328818284548246,1.0577856283950806,78.08999995361329,94.65000001220703
133,3.2839138141045203,1.0800661448669433,78.22400012939453,94.58000006103515
134,3.2449550261864295,1.0812386049079894,78.42199994628906,94.67800006103515
135,3.273659210938674,0.9925222877883911,78.83200018066407,94.73800013916015
136,3.279651779394883,1.1514595987319947,78.44199994628906,94.62399998291015
137,3.249285734616793,1.0616181856918334,78.67999995117188,94.70000003662109
138,3.2401852149229784,1.030422007408142,78.55200002441406,94.73399990478515
139,3.2287596739255466,1.11498403093338,78.648,94.83400006103516
140,3.2589325996545644,1.1934773761367798,78.630000078125,94.80000013916016
141,3.235685348510742,1.0107798411178588,78.718,94.85000006103516
142,3.255448185480558,1.0406203583145142,78.82800008300781,94.84600003417968
143,3.2458052726892324,1.1242790790367125,78.75999992675781,94.8079999584961
144,3.243349625514104,1.0674845689964294,78.73600005371094,94.85800008544922
145,3.208817793772771,1.053589370918274,78.48999994873047,94.74800000976562
146,3.247132796507615,1.0643479359817505,79.04200002441407,94.94800005859375
147,3.2469917077284594,1.0336022610664368,79.01399994873047,94.91199998291016
148,3.246920136305002,1.0803122089195252,78.92200005371093,94.84000011474609
149,3.2142061361899743,1.0599778313064576,78.99600010498047,94.92599995849609
150,3.217764240044814,1.06389358127594,79.10400002929687,94.95799995849609
151,3.216148073856647,1.0799826690673828,79.1300000024414,94.83200016357422
152,3.202606833898104,1.000363338546753,79.16200004882812,94.87800006103515
153,3.1680813752687893,1.026247396221161,79.16999989746094,94.98600003417968
154,3.2302153752400327,1.05421796377182,79.19600005371093,94.90400006103516
155,3.2140613060731154,1.034919361228943,79.1000000024414,94.9480001147461
156,3.212250150167025,1.0543945582008363,79.29199989746094,94.99999990478516
157,3.189672048275287,0.9840240316200256,79.35799997558594,95.16799995605469
158,3.1594845056533813,1.0651922596740722,79.29200010498047,94.97399990478516
159,3.169252661558298,1.0438192999267577,79.34,95.15399995605469
160,3.1778015265097985,0.9776468455314636,79.22999994628906,95.04400008544921
161,3.184331774711609,0.9410803900527954,79.59000010253907,95.23599998291016
162,3.172625651726356,1.1021912324523926,79.44200005126953,95.08000006103515
163,3.172793911053584,1.006711642932892,79.650000078125,95.21000003417969
164,3.1505272480157704,1.0151854071998596,79.25200004882812,95.04000016357422
165,3.2113893582270694,1.002980606880188,79.53400010742187,95.18599995605469
166,3.203436310474689,1.069907742919922,79.4719999243164,95.11200008544922
167,3.1434664726257324,0.9765472800827026,79.68800013183593,95.34800003173828
168,3.2033794017938466,0.954885217037201,79.74200010253907,95.34199995605469
169,3.157770037651062,1.015079356956482,79.69200018066407,95.14799992919922
170,3.1801951298346887,1.0224423956489563,79.84799997558594,95.30400006103515
171,3.183558913377615,1.0809200729370116,79.80000002929688,95.23600000732422
172,3.166443494650034,0.970890617313385,79.75400015625,95.26800005859376
173,3.1588093500870924,0.9810822336959839,79.93399997558593,95.40400006103516
174,3.1534451246261597,0.9683314956474304,80.04599997314453,95.36200006103516
175,3.176486510496873,0.9815351949119567,79.97000010498047,95.32399995849609
176,3.133654860349802,1.033868189277649,80.04200004882813,95.46000008300781
177,3.0971615681281457,1.057138032245636,79.988,95.32999988037109
178,3.0857103237738976,1.0150846409988403,80.01399987060547,95.31000024169921
179,3.1099898723455577,0.9767458859634399,80.17400010253907,95.46799990478516
180,3.0863860020270715,0.9960832705497742,80.35000010253906,95.4619998803711
181,3.1192220632846537,0.995442993927002,80.112,95.46599992919921
182,3.1065458059310913,0.9503561297225952,80.30000013183594,95.50400008789063
183,3.084474407709562,0.9327744506835938,80.504,95.55199990478516
184,3.1169589299422045,0.987274864768982,80.31800020507812,95.52000008544921
185,3.105303324185885,0.9985076886749268,80.57799994628907,95.55799992919921
186,3.086181319676913,0.9722424831962585,80.55400010498047,95.54000000732422
187,3.0753506880540113,1.005631930141449,80.45999997070312,95.51000008544922
188,3.084871640572181,1.0012442365074157,80.48399994628906,95.65999995849609
189,3.0642203917870154,0.9958682190895081,80.54799996582031,95.58800003173828
190,3.115300545325646,0.9633640588378907,80.41400002685548,95.60000003417969
191,3.0352935699316173,0.9269049572944641,80.61600001953126,95.64400008544922
192,3.096716523170471,0.9777986540985107,80.72600010742188,95.65000000976562
193,3.0447621712317834,1.020225089454651,80.82200004882813,95.61000000732422
194,3.0424715830729556,0.9350263250541687,80.91199994628906,95.66599992919922
195,3.04462403517503,0.9920734120368957,80.93400012695312,95.65400008544921
196,3.0816985093630276,0.9962930574035644,80.73600002441407,95.57600008544922
197,3.043275053684528,1.0144961289787293,80.85000002441406,95.69400008544922
198,3.0301405558219323,0.921901376285553,80.81000007080078,95.72999992919922
199,3.021504677259005,0.9390514798736572,80.93599991699219,95.72799998046875
200,3.0032437581282396,0.9052555331039429,80.92400002441406,95.72999990478516
201,3.0524399005449734,0.9862956017303467,80.98800002441406,95.70599992919922
202,3.04898725106166,0.9408524942779541,80.8440000756836,95.75799992919922
203,2.9965205651063185,0.9691278018569947,81.11999997070312,95.80000010986328
204,3.0177720876840444,0.9536201685333252,81.08200010253907,95.85800008544922
205,3.0006031714952908,0.8881150724983216,81.10200009765624,95.74999992919922
206,3.0060267448425293,1.0253046160316468,81.081999921875,95.70000010986328
207,2.9632287667347836,0.9307623635292053,81.20199997070313,95.88400010986328
208,3.014235395651597,0.980958409614563,81.24000006835938,95.85200008300781
209,2.9804076965038595,0.9197839248657227,81.26800004638672,95.75000016357421
210,3.0220154248751125,0.9738036029434204,81.32600002441406,95.93999995605469
211,2.9515054042522726,0.8529443983078003,81.53800001953125,95.91200010986329
212,3.0006972001149106,0.9346630925750733,81.41200010253907,95.92800008544921
213,2.993186171238239,0.9045049831581116,81.31800002197265,95.89400005859375
214,3.0099153335277853,0.9540823571968079,81.43400012207032,95.90000008544922
215,2.9962432842988234,0.961731686630249,81.57800007568359,95.9020000830078
216,2.964731344809899,0.9199039099502564,81.68000009765625,95.98400018798829
217,2.976665368446937,0.9713804598045349,81.63999996826172,95.97400010986328
218,2.938941423709576,0.9095170165824891,81.77400017822265,95.99000010986327
219,2.9089973247968235,0.9290500170707703,81.55399999511718,96.03600003173828
220,2.972458839416504,0.9646587027359009,81.6720000756836,95.95600000488281
221,2.9590316002185526,0.9856760236167907,81.76999997314454,96.00800013671875
222,2.924618986936716,0.8967197008514405,81.8499999975586,95.98600000732422
223,2.9153974881538978,0.9025303680419922,81.80200012695313,96.00800008544923
224,2.9513279658097487,0.9642524267578125,81.81600002441407,95.99200000732422
225,2.925851684350234,0.9187241349601746,81.70800007324219,96.08600005859375
226,2.944014842693622,0.9485834597969055,81.83400004882813,96.16800005859375
227,2.9328464544736423,0.9324082082939148,81.9540000756836,96.17800000732421
228,2.925115887935345,0.8715889813232421,82.11000005371093,96.16800003173829
229,2.882812270751366,0.9478203728675842,81.99599999511719,96.19000018798828
230,2.893248823972849,0.8939307011795044,82.06799997070313,96.13000008544923
231,2.909112334251404,0.9412263136482238,82.08400004882813,96.15599998046875
232,2.8751327991485596,0.9333421780395508,82.09000010253907,96.07600006103516
233,2.8806763612307034,0.8944049692344666,82.09599999511718,96.09800000732422
234,2.8830181176845846,0.8789783415031434,82.22800002197266,96.13400000732422
235,2.883752153469966,0.9032074878883362,82.192,96.12400000732421
236,2.8425116630700917,0.9093195487785339,82.38400007568359,96.11000010986328
237,2.8728231375034037,0.9299646599769592,82.35999999511719,96.20400003173827
238,2.854443614299481,0.8940147971343995,82.43400007324219,96.19600005859375
239,2.8340333149983334,0.8646629189872742,82.15599997070312,96.22799992919921
240,2.834600338569054,0.9099975925445557,82.34599997070312,96.26400000732421
241,2.8750672890589786,0.9240939585876465,82.35999994628907,96.22399990478516
242,2.881342814518855,0.9277827157592774,82.29400004882812,96.18600000732422
243,2.800852289566627,0.9207671698379517,82.42399997070312,96.29199995605468
244,2.826339785869305,0.8509364444160461,82.48799994140624,96.32199992919922
245,2.832765579223633,0.8535477470588684,82.51800002197265,96.30199992919921
246,2.835671140597417,0.8540577458381653,82.61000007324219,96.34400010986329
247,2.835276411129878,0.88917942527771,82.58200007324218,96.36600016357421
248,2.8517292187764096,0.8996915315628051,82.60200004638672,96.34800005859375
249,2.83868577847114,0.8852077262115479,82.71800018066406,96.34200000732422
250,2.795175653237563,0.868073853187561,82.7640000756836,96.36400003417968
251,2.8127903846594005,0.9428228170013427,82.69600015136719,96.34999992919921
252,2.7924872911893406,0.8700160450935364,82.73600002441407,96.37400000732421
253,2.8241086097864003,0.9485920486831665,82.66599991699219,96.35400008544921
254,2.820255627998939,0.8666281459808349,82.63400012207032,96.34800013671875
255,2.767980071214529,0.9426617022705078,82.69000001953125,96.39400010986328
256,2.7815559827364407,0.8850377077674866,82.79000001953125,96.40200008544922
257,2.7684466472038856,0.9531537490272523,82.81200012695312,96.39999998291016
258,2.776091988270099,0.9598657544517517,82.78800007324219,96.41400000732422
259,2.7829883373700657,0.8710285722160339,82.94200017578125,96.40200010986328
260,2.7462159945414615,0.8686052158546448,82.83200002197266,96.44399998046875
261,2.745678883332473,0.8464594898796082,83.03600012451172,96.44399998046875
262,2.746650503231929,0.9291099748039245,82.88800014892578,96.4680001123047
263,2.7892674849583554,0.9198091730690002,82.92000009765626,96.48000010986328
264,2.7777191950724673,0.876633308506012,83.00000009765625,96.48200000732422
265,2.790822093303387,0.8607185919570923,82.98600007324218,96.47999995605468
266,2.7587546568650465,0.8637187901306153,83.12800012207032,96.47200010986329
267,2.7561703461867113,0.9178548712730408,83.01000004394531,96.47800000732421
268,2.753976033284114,0.9308101778793335,83.02200007324218,96.47800005859375
269,2.7478087773689857,0.905781350402832,83.07399999511719,96.48400005859375
270,2.7359400345728946,0.8987113190078735,83.10199999511718,96.49800000732422
271,2.7687891721725464,0.8974916218185425,83.10600001953125,96.47800005859375
272,2.7437755236258874,0.8698487334251404,83.1459999951172,96.53000010986328
273,2.740336509851309,0.8323400824928283,83.19800015136718,96.50200005859375
274,2.7227883522327128,0.9540283364486695,83.18800015136719,96.49400000732422
275,2.7484019077741184,0.8897652346038818,83.22799999511719,96.52200005859375
276,2.7232350294406595,0.8671867720222474,83.13600015136718,96.55000005859375
277,2.7706412352048435,0.914162262840271,83.18800010009765,96.52600005859375
278,2.7158928284278283,0.8688316952896118,83.13200010009766,96.55800000732422
279,2.757467407446641,0.8733858386993408,83.24800004638672,96.50800005859375
280,2.680275009228633,0.8916000234222412,83.22200002197266,96.53800005859375
281,2.7301262342012844,0.9303448626327515,83.25600007324219,96.54400005859375
282,2.7311555605668287,0.9017908967781066,83.27000007324219,96.53000005859376
283,2.778689439480121,0.8907459851264954,83.33400012451172,96.55800005859375
284,2.730280564381526,0.8617753137397766,83.29599994384766,96.55600000732422
285,2.6996115996287418,0.9456994646072387,83.32199994384766,96.54800000732422
286,2.7530321524693417,0.8394892182922363,83.29600002197266,96.53600000732422
287,2.698102043225215,0.871575404548645,83.29799999511718,96.52800000732422
288,2.7162117774669943,0.9279066291618348,83.29799994384766,96.51800000732422
289,2.731216476513789,0.870179077796936,83.27600007324219,96.53200005859375
290,2.7172455145762515,0.8693408039665222,83.32000009765625,96.56200000732422
291,2.7217825100972104,0.862339599533081,83.32000002197266,96.51200000732422
292,2.738798902584956,0.8762007956314087,83.29199994384766,96.54600000732422
293,2.6903779139885535,0.9252629530525207,83.30799994384766,96.55800000732422
294,2.7017551843936625,0.8361762863922119,83.3040000732422,96.57200000732422
295,2.7112444455807028,0.8382016857719421,83.30800012451172,96.52600000732421
296,2.7227955231299767,0.8906204897880554,83.32800012451172,96.53000000732422
297,2.702854578311627,0.8682446762657166,83.31999994384766,96.54200000732422
298,2.7227264734414907,0.9243216412544251,83.33200002197266,96.49400000732422
299,2.6936778105222263,0.813800574092865,83.33000002197265,96.56000000732422
300,2.7342205964601956,0.9972933919715882,83.30199999511719,96.50200000732421
301,2.696374086233286,0.8523485798072815,83.36800012451172,96.52000000732421
302,2.732568465746366,0.9299849660491943,83.33600004638672,96.53600000732422
303,2.7426971655625563,0.8711758170700074,83.34400007324219,96.53000000732422
304,2.7026734902308536,0.9297433455467224,83.31600012451172,96.51400000732421
305,2.7032283636239858,0.9165915111351013,83.37599994384766,96.53600000732422
306,2.7004379217441263,0.8460414072036743,83.38199999511718,96.55800000732422
307,2.7239219958965597,0.8808569229316712,83.37600012451172,96.53800000732421
308,2.6918300573642435,0.8398378726959228,83.37200002197265,96.55200000732422
309,2.699970080302312,0.8693196088981628,83.35399994384765,96.55000000732421
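Each summary.csv logs one row per epoch with `train_loss`, `eval_loss`, `eval_top1`, and `eval_top5`. A small sketch of pulling the best top-1 epoch out of such a log with the stdlib `csv` module (the inline rows are copied verbatim from the file above; for a real run, pass an opened file instead of the `StringIO` wrapper):

```python
import csv
import io

# Header plus three rows copied verbatim from
# checkpoint/iformer_small/summary.csv above.
SAMPLE = """\
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,6.9094244333413934,6.864997950286865,0.3860000048828125,1.5220000048828124
157,3.189672048275287,0.9840240316200256,79.35799997558594,95.16799995605469
306,2.7004379217441263,0.8460414072036743,83.38199999511718,96.55800000732422
"""

def best_epoch(fp):
    """Return (epoch, eval_top1) for the row with the highest top-1 accuracy."""
    rows = csv.DictReader(fp)
    best = max(rows, key=lambda r: float(r["eval_top1"]))
    return int(best["epoch"]), float(best["eval_top1"])

epoch, top1 = best_epoch(io.StringIO(SAMPLE))
print(epoch, top1)
```

With the three sample rows, the maximum falls on epoch 306; on the full file the same function would scan all 310 rows.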
================================================
FILE: checkpoint_384/iformer_base_384/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: true
aug_repeats: 3
aug_splits: 0
batch_size: 32
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 0.1
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 30.0
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.5
embed_dim: 384
epoch_repeats: 0.0
epochs: 20
eval_metric: top1
experiment: iformer_base_384
gp: null
hflip: 0.5
img_size: 384
initial_checkpoint: checkpoint/iformer_base/model_best.pth.tar
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 5.0e-06
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 5.0e-07
mixup: 0.1
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_base_384
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: 1.0e-08
output: checkpoint_384
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 0
warmup_lr: 2.0e-08
weight_decay: 1.0e-08
worker_seeding: all
workers: 10
================================================
FILE: checkpoint_384/iformer_base_384/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,2.162344753742218,0.6423477774047851,85.416,97.554
1,2.14225176152061,0.6349599724960328,85.544,97.552
2,2.1576704359522054,0.63505445728302,85.514,97.558
3,2.151329924078549,0.6296176392364502,85.482,97.552
4,2.16335995758281,0.6267058710479736,85.574,97.572
5,2.123039128733616,0.6339057161712647,85.602,97.582
6,2.1556805185243193,0.6313015496826172,85.622,97.6
7,2.1183225535878947,0.6357844618606567,85.594,97.558
8,2.143087193077686,0.6312218860626221,85.548,97.582
9,2.0986259544596955,0.625974910697937,85.654,97.574
10,2.111677838306801,0.6280780986785889,85.61,97.62
11,2.139703644256966,0.6327868656158447,85.698,97.598
12,2.1040731925590364,0.6322624744033813,85.684,97.592
13,2.1542149863991082,0.6258501932525635,85.684,97.594
14,2.1099117760564767,0.6318398251342774,85.672,97.57
15,2.1123981055091408,0.6286558788299561,85.654,97.628
16,2.1304921811702204,0.6399081578063965,85.676,97.594
17,2.138498722338209,0.6392654892730713,85.702,97.604
18,2.1148997580303863,0.6415886449813842,85.67,97.604
19,2.1841310519798127,0.6299422284317017,85.648,97.612
20,2.114746145173615,0.6347110688018799,85.674,97.604
21,2.141511425083759,0.6325080072784424,85.68,97.604
22,2.1408919329736746,0.6330610972595215,85.696,97.61
23,2.14104422751595,0.6302861064147949,85.702,97.604
24,2.1165792801800896,0.6278342678070068,85.69,97.596
25,2.1689459833444333,0.6362991049194336,85.706,97.618
26,2.106649726044898,0.6317873718261718,85.69,97.612
27,2.130349175602782,0.6273881248474121,85.698,97.602
28,2.101420769504472,0.629305263748169,85.692,97.6
29,2.1075617484017912,0.6330888885116577,85.67,97.614
================================================
FILE: checkpoint_384/iformer_large_384/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: true
aug_repeats: 3
aug_splits: 0
batch_size: 32
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 0.1
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 30.0
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.6
embed_dim: 384
epoch_repeats: 0.0
epochs: 20
eval_metric: top1
experiment: iformer_large_384
gp: null
hflip: 0.5
img_size: 384
initial_checkpoint: checkpoint/iformer_large/model_best.pth.tar
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 1.0e-05
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-06
mixup: 0.1
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_large_384
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: 1.0e-08
output: checkpoint_384
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 0
warmup_lr: 2.0e-08
weight_decay: 1.0e-08
worker_seeding: all
workers: 10
================================================
FILE: checkpoint_384/iformer_large_384/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,2.1128095227938433,0.6339371264839172,85.65200001953124,97.5399999609375
1,2.112883148285059,0.6371526368713379,85.67800004394532,97.5179999609375
2,2.081270286670098,0.6324703786468506,85.65800001953124,97.52000003417969
3,2.0541468331447015,0.6338681053352356,85.69799999511719,97.5360000341797
4,2.1174410466964426,0.6423702626800537,85.70999999267578,97.52200000976562
5,2.06852941100414,0.6357672811508178,85.6640000390625,97.57600003417969
6,2.0652733536866994,0.6328511357688904,85.69000004394532,97.52999998535157
7,2.090355485677719,0.6328032349014282,85.73000006835937,97.57000001220703
8,2.0464458832373986,0.6333639372062683,85.70000001708985,97.5600000341797
9,2.1092736858588,0.6316976878166198,85.66799996582031,97.57000000976562
10,2.0740141524718356,0.6331033121871948,85.74800004394531,97.51800000976563
11,2.0925170343655806,0.6338267395019531,85.76999999511719,97.54000000976562
12,2.0762338707080255,0.6300550039291382,85.79600004394531,97.58400000976563
13,2.0655095875263214,0.6340151790809632,85.72400001953125,97.56800000976563
14,2.0131444334983826,0.6341021949386597,85.77399999023437,97.56400000976562
15,2.045399464093722,0.6350232722473145,85.78200004394532,97.56200000976563
16,2.0737479076935696,0.6348873478507996,85.78999999023438,97.56200000976563
17,2.0775972329653225,0.6329525914382934,85.7680000439453,97.55200000976562
18,2.0284574719575734,0.6307329417991638,85.77000004394532,97.56600000976563
19,2.0756700291083408,0.6293472344970703,85.79199996582031,97.56600000976563
20,2.0585412039206576,0.6404127075767517,85.81800001464843,97.56200000976563
21,2.107305421279027,0.6296587294960022,85.83400001708985,97.56400000976562
22,2.0651893615722656,0.6286195637512207,85.81999999023438,97.55600000976563
23,2.0812353377158823,0.6351058733177185,85.80600004394532,97.56200000976563
24,2.030002793440452,0.630206987991333,85.84400001708984,97.55400000976563
25,2.037764448385972,0.6358882413291931,85.80599999267578,97.56200000976563
26,2.0813280573258033,0.6310837986183166,85.77999999267578,97.55800000976562
27,2.06927875830577,0.6295442185211182,85.83399999267579,97.56600000976563
28,2.0359999583317685,0.6304517384147644,85.80800001708984,97.55800000976562
29,2.0422337972200832,0.6342641407775879,85.80799999267578,97.53600000976563
================================================
FILE: checkpoint_384/iformer_small_384/args.yaml
================================================
aa: rand-m9-mstd0.5-inc1
amp: false
apex_amp: false
aug_repeats: 0
aug_splits: 0
batch_size: 32
bce_loss: false
bce_target_thresh: null
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 1
class_map: ''
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: null
cutmix: 0.1
cutmix_minmax: null
data_dir: /dataset/imagenet-raw
dataset: ''
dataset_download: false
decay_epochs: 100
decay_rate: 0.1
dist_bn: reduce
drop: 0.0
drop_block: null
drop_connect: null
drop_path: 0.3
embed_dim: 384
epoch_repeats: 0.0
epochs: 20
eval_metric: top1
experiment: iformer_small_384
gp: null
hflip: 0.5
img_size: 384
initial_checkpoint: checkpoint/iformer_small/iformer_small_checkpoint.pth
input_size: null
interpolation: ''
jsd_loss: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 1.0e-05
lr_cycle_decay: 0.5
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_k_decay: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-06
mixup: 0.1
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: iformer_small_384
model_ema: false
model_ema_decay: 0.9998
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_ddp_bb: false
no_prefetcher: true
no_resume_opt: false
num_classes: null
opt: adamw
opt_betas: null
opt_eps: null
output: checkpoint_384
patience_epochs: 10
pin_mem: false
port: '25500'
pretrained: false
ratio:
- 0.75
- 1.3333333333333333
recount: 1
recovery_interval: 0
remode: pixel
reprob: 0.25
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size: null
vflip: 0.0
warmup_epochs: 0
warmup_lr: 2.0e-08
weight_decay: 1.0e-08
worker_seeding: all
workers: 10
================================================
FILE: checkpoint_384/iformer_small_384/summary.csv
================================================
epoch,train_loss,eval_loss,eval_top1,eval_top5
0,2.3478839855927687,0.6923815357589722,84.32999999023437,97.18999998535156
1,2.3396820655235877,0.6849766717720032,84.33399996582031,97.14000003417969
2,2.2849164650990414,0.7006625423812867,84.32999998779297,97.19599995605469
3,2.2638514316998997,0.6839149640274048,84.42799999023437,97.21000003417969
4,2.3003986065204325,0.6951746271324157,84.51999996582032,97.20600003417968
5,2.265346252001249,0.6716998846244812,84.4420000390625,97.20200003417969
6,2.2564402039234457,0.677104370727539,84.4259999633789,97.22399998535157
7,2.3008200640861807,0.6659402999305725,84.46000001464844,97.27000000976562
8,2.253789108533126,0.6801115174865723,84.54399996582032,97.21400003417969
9,2.3075229708965006,0.6605272066688538,84.46800001464844,97.1820000341797
10,2.284832917726957,0.6837751863861083,84.48199999023437,97.19200000976562
11,2.3022379279136658,0.6809691307640076,84.53400001464844,97.21200000976563
12,2.2808263347699094,0.6799604147911071,84.47399999267579,97.25200003417969
13,2.2778979677420397,0.6703613100624084,84.52200001953125,97.21200000976563
14,2.220526851140536,0.6813768936157226,84.49400001953126,97.22800003417969
15,2.236867505770463,0.6821543011665344,84.56800001953125,97.22800003417969
16,2.2702108713296743,0.6934137293052673,84.51199999023437,97.22400003417968
17,2.2874821562033434,0.7100320655632019,84.55599999023437,97.21800003417968
18,2.2137838762540083,0.6670947971534729,84.56599996582031,97.20400000976562
19,2.2960855869146495,0.673701826992035,84.54999999023437,97.22600003417969
20,2.260560466692998,0.6704543328857422,84.62000001464844,97.22000000976563
21,2.3083780178656945,0.6758274334907531,84.5859999658203,97.21600000976562
22,2.276798422519977,0.6762671936225891,84.54399996582032,97.19600000976563
23,2.2856121980226956,0.6660497138595581,84.61199996582032,97.17600000976563
24,2.2410703026331387,0.68424802734375,84.5979999658203,97.24800000976562
25,2.248056019728,0.65673936378479,84.59199999023437,97.23400000976562
26,2.2597031868421116,0.6829823554992676,84.56999999023438,97.21800003417968
27,2.2474928085620585,0.6944181496810913,84.5619999658203,97.23200003417969
28,2.255860697764617,0.6809109293937683,84.54399996582032,97.24600000976562
29,2.2558590265420766,0.6986223343276977,84.56799996337891,97.25800003417969
================================================
FILE: fine-tune.py
================================================
#!/usr/bin/env python3
""" ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
# from types import _KT
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import *
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset parameters
parser.add_argument('data_dir', metavar='DIR',
help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
# Model parameters
parser.add_argument('--model', default='deit_small_patch16_224', type=str, metavar='MODEL',
help='Name of model to train (default: "deit_small_patch16_224")')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image size (default: None => model default)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='validation batch size override (default: None)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=2e-5,
help='weight decay (default: 2e-5)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
help='learning rate (default: 0.05)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
parser.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
parser.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--aug-repeats', type=int, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('--checkpoint-hist', type=int, default=1, metavar='N',
help='number of checkpoints to keep (default: 1)')
parser.add_argument('-j', '--workers', type=int, default=10, metavar='N',
help='how many training processes to use (default: 10)')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=True,
help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
parser.set_defaults(**cfg)
# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
args = parser.parse_args(remaining)
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
def load_init_checkpoint(model, checkpoint_path):
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
_logger.info('Restoring model state from checkpoint...')
new_state_dict = OrderedDict()
# model_state_dict = model.state_dict()
for k, v in checkpoint['state_dict'].items():
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
# model_state_dict[name] = v
model.load_state_dict(new_state_dict)
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
else:
model.load_state_dict(checkpoint)
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
else:
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))
def main():
setup_default_logging()
args, args_text = _parse_args()
if args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
_logger.info('Training with a single process on 1 GPU.')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
"Install NVIDIA Apex or upgrade to PyTorch 1.6")
random_seed(args.seed, args.rank)
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
model.eval()
flops = FlopCountAnalysis(model, torch.rand(1, 3, 384, 384))
if args.local_rank == 0:
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
print(flop_count_table(flops))
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.distributed and args.sync_bn:
assert not args.split_bn
if has_apex and use_amp == 'apex':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
if os.path.exists(args.resume):
resume_epoch = resume_checkpoint(
model, args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0,
)
if args.initial_checkpoint:
if os.path.exists(args.initial_checkpoint):
load_init_checkpoint(
model, args.initial_checkpoint,
)
# memory=student_mem
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEmaV2(
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
# setup distributed training
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
_logger.info('Scheduled epochs: {}'.format(num_epochs))
# create the train and eval datasets
dataset_train = create_dataset(
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
# setup mixup / cutmix
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
if collate_fn is not None:
print('collate_fn is not none')
if mixup_fn is not None:
print('mixup_fn is not none')
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
use_prefetcher=args.prefetcher,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
auto_augment=args.aa,
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
batch_size=args.validation_batch_size or args.batch_size,
is_training=False,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
)
# setup loss function
if args.jsd_loss:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform which outputs sparse, soft targets
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
else:
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
else:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
output_dir = None
if args.rank == 0:
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
try:
entropy_thr = 0
for epoch in range(start_epoch, num_epochs):
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
train_metrics, all_entropy = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
entropy_thr=entropy_thr)
# all_entropy = torch.stack(all_entropy, dim=0)
# entropy_thr = all_entropy.mean()
# entropy_thr = entropy_thr * 2.0
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if output_dir is not None:
update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
except KeyboardInterrupt:
pass
if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None, entropy_thr=None):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
losses_m = AverageMeter()
model.train()
all_entropy = []
end = time.time()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if mixup_fn is not None:
input_mix, target_mix = mixup_fn(input, target)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input_mix)
# output = model(input)
loss = loss_fn(output, target_mix)
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'LR: {lr:.3e} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
end = time.time()
# end for
if hasattr(optimizer, 'sync_lookahead'):
optimizer.sync_lookahead()
return OrderedDict([('loss', losses_m.avg)]), all_entropy
# utils
@torch.no_grad()
def concat_all_gather(tensor):
"""
Performs an all_gather on the provided tensor and concatenates the results
from all processes.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""
tensors_gather = [torch.ones_like(tensor)
for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
output = torch.cat(tensors_gather, dim=0)
return output
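A minimal single-process sketch of the gather-and-concat pattern used by `concat_all_gather` above, using the CPU `gloo` backend so it runs without GPUs. The address and port values are arbitrary choices for this example.

```python
import os
import torch
import torch.distributed as dist

# Single-process "distributed" group over the gloo (CPU) backend.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29512')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

t = torch.arange(4, dtype=torch.float32).reshape(2, 2)
# One receive buffer per process, then gather and concatenate along dim 0.
gathered = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, t)
out = torch.cat(gathered, dim=0)  # with world_size=1 this equals the input

dist.destroy_process_group()
```

With more than one process, `out` would stack each rank's tensor in rank order; note that, as the docstring warns, no gradient flows through `all_gather`.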
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
batch_time_m = AverageMeter()
losses_m = AverageMeter()
top1_m = AverageMeter()
top5_m = AverageMeter()
model.eval()
end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input)
# print(output.size())
if isinstance(output, (tuple, list)):
output = output[0]
# augmentation reduction
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
acc1 = reduce_tensor(acc1, args.world_size)
acc5 = reduce_tensor(acc5, args.world_size)
else:
reduced_loss = loss.data
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))
batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
'{0}: [{1:>4d}/{2}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m,
loss=losses_m, top1=top1_m, top5=top5_m))
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
return metrics
if __name__ == '__main__':
main()
================================================
FILE: models/__init__.py
================================================
from .inception_transformer import *
================================================
FILE: models/inception_transformer.py
================================================
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inception transformer implementation.
Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models).
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from timm.models.registry import register_model
from torch.nn.init import _calculate_fan_in_and_fan_out
import math
import warnings
from timm.models.layers.helpers import to_2tuple
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
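The `_cfg` helper above builds per-model config dicts where keyword overrides win over the defaults (as the 384-resolution entries below use for `input_size` and `crop_pct`). A standalone sketch of that merge behavior, with a hypothetical `make_cfg` and example URLs:

```python
# Hypothetical standalone version of the _cfg merge: kwargs override defaults.
def make_cfg(url='', **kwargs):
    base = {'url': url, 'num_classes': 1000,
            'input_size': (3, 224, 224), 'crop_pct': 0.9}
    return {**base, **kwargs}

cfg_224 = make_cfg('https://example.com/weights.pth')
cfg_384 = make_cfg('https://example.com/weights_384.pth',
                   input_size=(3, 384, 384), crop_pct=1.0)
```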
# default_cfgs = {
# 'iformer_224': _cfg(),
# 'iformer_384': _cfg(input_size=(3, 384, 384), crop_pct=1.0),
# }
default_cfgs = {
'iformer_small': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small.pth'),
'iformer_base': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base.pth'),
'iformer_large': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large.pth'),
'iformer_small_384': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_small_384.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'iformer_base_384': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_base_384.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'iformer_large_384': _cfg(url='https://huggingface.co/sail/dl2/resolve/main/iformer/iformer_large_384.pth',
input_size=(3, 384, 384), crop_pct=1.0),
}
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
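A quick sanity check of the truncated-normal init: every sampled value must land inside `[a, b]`, and with symmetric bounds the sample mean stays near zero. This sketch uses PyTorch's built-in `torch.nn.init.trunc_normal_`, which follows the same inverse-CDF approach as the copy above.

```python
import torch

torch.manual_seed(0)
w = torch.empty(1000)
# Sample from N(0, 1) truncated to [-2, 2].
torch.nn.init.trunc_normal_(w, mean=0., std=1., a=-2., b=2.)
in_bounds = bool(((w >= -2.) & (w <= 2.)).all())
sample_mean = w.mean().item()
```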
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == 'fan_in':
denom = fan_in
elif mode == 'fan_out':
denom = fan_out
elif mode == 'fan_avg':
denom = (fan_in + fan_out) / 2
else:
raise ValueError(f"invalid mode {mode}")
variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
elif distribution == "normal":
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
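The fan computation behind `variance_scaling_`: for a Conv2d weight of shape `(out_ch, in_ch, kH, kW)`, `fan_in = in_ch * kH * kW` and `fan_out = out_ch * kH * kW`, so e.g. `scale / fan_in` gives the target variance in `fan_in` mode. A small check using the same private helper imported above:

```python
import torch
from torch.nn.init import _calculate_fan_in_and_fan_out

# (out_ch=8, in_ch=4, 3x3 kernel) -> fan_in = 4*3*3, fan_out = 8*3*3
w = torch.empty(8, 4, 3, 3)
fan_in, fan_out = _calculate_fan_in_and_fan_out(w)
```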
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, kernel_size=16, stride=16, padding=0, in_chans=3, embed_dim=768):
super().__init__()
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding )
self.norm = nn.BatchNorm2d(embed_dim)
def forward(self, x):
# B, C, H, W = x.shape
x = self.proj(x)
x = self.norm(x)
x = x.permute(0,2,3,1)
return x
class FirstPatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, kernel_size=3, stride=2, padding=1, in_chans=3, embed_dim=768):
super().__init__()
self.proj1 = nn.Conv2d(in_chans, embed_dim//2, kernel_size=kernel_size, stride=stride, padding=padding )
self.norm1 = nn.BatchNorm2d(embed_dim // 2)
self.gelu1 = nn.GELU()
self.proj2 = nn.Conv2d(embed_dim//2, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding )
self.norm2 = nn.BatchNorm2d(embed_dim)
def forward(self, x):
# B, C, H, W = x.shape
x = self.proj1(x)
x = self.norm1(x)
x = self.gelu1(x)
x = self.proj2(x)
x = self.norm2(x)
x = x.permute(0,2,3,1)
return x
class HighMixer(nn.Module):
def __init__(self, dim, kernel_size=3, stride=1, padding=1,
**kwargs, ):
super().__init__()
self.cnn_in = cnn_in = dim // 2
self.pool_in = pool_in = dim // 2
self.cnn_dim = cnn_dim = cnn_in * 2
self.pool_dim = pool_dim = pool_in * 2
self.conv1 = nn.Conv2d(cnn_in, cnn_dim, kernel_size=1, stride=1, padding=0, bias=False)
self.proj1 = nn.Conv2d(cnn_dim, cnn_dim, kernel_size=kernel_size, stride=stride, padding=padding, bias=False, groups=cnn_dim)
self.mid_gelu1 = nn.GELU()
self.Maxpool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding)
self.proj2 = nn.Conv2d(pool_in, pool_dim, kernel_size=1, stride=1, padding=0)
self.mid_gelu2 = nn.GELU()
def forward(self, x):
# B, C H, W
cx = x[:,:self.cnn_in,:,:].contiguous()
cx = self.conv1(cx)
cx = self.proj1(cx)
cx = self.mid_gelu1(cx)
px = x[:,self.cnn_in:,:,:].contiguous()
px = self.Maxpool(px)
px = self.proj2(px)
px = self.mid_gelu2(px)
hx = torch.cat((cx, px), dim=1)
return hx
class LowMixer(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pool_size=2,
**kwargs, ):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.dim = dim
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.pool = nn.AvgPool2d(pool_size, stride=pool_size, padding=0, count_include_pad=False) if pool_size > 1 else nn.Identity()
self.uppool = nn.Upsample(scale_factor=pool_size) if pool_size > 1 else nn.Identity()
def att_fun(self, q, k, v, B, N, C):
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(2, 3).reshape(B, C, N)
return x
def forward(self, x):
# B, C, H, W
B, _, _, _ = x.shape
xa = self.pool(x)
xa = xa.permute(0, 2, 3, 1).view(B, -1, self.dim)
B, N, C = xa.shape
qkv = self.qkv(xa).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
xa = self.att_fun(q, k, v, B, N, C)
xa = xa.view(B, C, int(N**0.5), int(N**0.5))#.permute(0, 3, 1, 2)
xa = self.uppool(xa)
return xa
class Mixer(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., attention_head=1, pool_size=2,
**kwargs, ):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim = dim // num_heads
self.low_dim = low_dim = attention_head * head_dim
self.high_dim = high_dim = dim - low_dim
self.high_mixer = HighMixer(high_dim)
self.low_mixer = LowMixer(low_dim, num_heads=attention_head, qkv_bias=qkv_bias, attn_drop=attn_drop, pool_size=pool_size,)
self.conv_fuse = nn.Conv2d(low_dim+high_dim*2, low_dim+high_dim*2, kernel_size=3, stride=1, padding=1, bias=False, groups=low_dim+high_dim*2)
self.proj = nn.Conv2d(low_dim+high_dim*2, dim, kernel_size=1, stride=1, padding=0)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, H, W, C = x.shape
x = x.permute(0, 3, 1, 2)
hx = x[:,:self.high_dim,:,:].contiguous()
hx = self.high_mixer(hx)
lx = x[:,self.high_dim:,:,:].contiguous()
lx = self.low_mixer(lx)
x = torch.cat((hx, lx), dim=1)
x = x + self.conv_fuse(x)
x = self.proj(x)
x = self.proj_drop(x)
x = x.permute(0, 2, 3, 1).contiguous()
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention_head=1, pool_size=2,
attn=Mixer,
use_layer_scale=False, layer_scale_init_value=1e-5,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, attention_head=attention_head, pool_size=pool_size,)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.use_layer_scale = use_layer_scale
if self.use_layer_scale:
# print('use layer scale init value {}'.format(layer_scale_init_value))
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
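A sketch of the LayerScale residual update used in `Block.forward` above: the per-channel scale starts near zero (`layer_scale_init_value`), so each block initially behaves close to the identity, which tends to stabilize training of deep stacks. The tensor shapes here are arbitrary illustrative values.

```python
import torch

dim = 4
layer_scale = 1e-5 * torch.ones(dim)   # learnable in the real Block
x = torch.randn(2, 3, 3, dim)          # (B, H, W, C) layout, as in this model
branch = torch.randn_like(x)           # stand-in for the attn/mlp branch output
out = x + layer_scale * branch         # scale broadcasts over the channel dim
delta = (out - x).abs().max().item()   # tiny at init: block ~= identity
```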
class InceptionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=None, depths=None,
num_heads=None, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init='',
attention_heads=None,
use_layer_scale=False, layer_scale_init_value=1e-5,
checkpoint_path=None,
**kwargs,
):
super().__init__()
st2_idx = sum(depths[:1])
st3_idx = sum(depths[:2])
st4_idx = sum(depths[:3])
depth = sum(depths)
self.num_classes = num_classes
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.patch_embed = FirstPatchEmbed(in_chans=in_chans, embed_dim=embed_dims[0])
self.num_patches1 = num_patches = img_size // 4
self.pos_embed1 = nn.Parameter(torch.zeros(1, num_patches, num_patches, embed_dims[0]))
self.blocks1 = nn.Sequential(*[
Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, attention_head=attention_heads[i], pool_size=2,)
# use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
# )
for i in range(0, st2_idx)])
self.patch_embed2 = embed_layer(kernel_size=3, stride=2, padding=1, in_chans=embed_dims[0], embed_dim=embed_dims[1])
self.num_patches2 = num_patches = num_patches // 2
self.pos_embed2 = nn.Parameter(torch.zeros(1, num_patches, num_patches, embed_dims[1]))
self.blocks2 = nn.Sequential(*[
Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, attention_head=attention_heads[i], pool_size=2,)
# use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value, channel_layer_scale=channel_layer_scale,
# )
for i in range(st2_idx,st3_idx)])
self.patch_embed3 = embed_layer(kernel_size=3, stride=2, padding=1, in_chans=embed_dims[1], embed_dim=embed_dims[2])
self.num_patches3 = num_patches = num_patches // 2
self.pos_embed3 = nn.Parameter(torch.zeros(1, num_patches, num_patches, embed_dims[2]))
self.blocks3= nn.Sequential(*[
Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, attention_head=attention_heads[i], pool_size=1,
use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
)
for i in range(st3_idx, st4_idx)])
self.patch_embed4 = embed_layer(kernel_size=3, stride=2, padding=1, in_chans=embed_dims[2], embed_dim=embed_dims[3])
self.num_patches4 = num_patches = num_patches // 2
self.pos_embed4 = nn.Parameter(torch.zeros(1, num_patches, num_patches, embed_dims[3]))
self.blocks4 = nn.Sequential(*[
Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, attention_head=attention_heads[i], pool_size=1,
use_layer_scale=use_layer_scale, layer_scale_init_value=layer_scale_init_value,
)
for i in range(st4_idx,depth)])
self.norm = norm_layer(embed_dims[-1])
# Classifier head(s)
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
# set post block, for example, class attention layers
self.init_weights(weight_init)
def init_weights(self, mode=''):
trunc_normal_(self.pos_embed1, std=.02)
trunc_normal_(self.pos_embed2, std=.02)
trunc_normal_(self.pos_embed3, std=.02)
trunc_normal_(self.pos_embed4, std=.02)
self.apply(_init_vit_weights)
def _init_weights(self, m):
# this fn left here for compat with downstream users
_init_vit_weights(m)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'dist_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
# NOTE: this model has no dist_token/head_dist; infer the feature dim from the final norm
self.num_classes = num_classes
embed_dim = self.norm.normalized_shape[0]
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _get_pos_embed(self, pos_embed, num_patches_def, H, W):
if H * W == num_patches_def * num_patches_def:
return pos_embed
else:
return F.interpolate(
pos_embed.permute(0, 3, 1, 2),
size=(H, W), mode="bilinear").permute(0, 2, 3, 1)
def forward_features(self, x):
x = self.patch_embed(x)
B, H, W, C = x.shape
x = x + self._get_pos_embed(self.pos_embed1, self.num_patches1, H, W)
x = self.blocks1(x)
x = x.permute(0, 3, 1, 2)
x = self.patch_embed2(x)
B, H, W, C = x.shape
x = x + self._get_pos_embed(self.pos_embed2, self.num_patches2, H, W)
x = self.blocks2(x)
x = x.permute(0, 3, 1, 2)
x = self.patch_embed3(x)
B, H, W, C = x.shape
x = x + self._get_pos_embed(self.pos_embed3, self.num_patches3, H, W)
x = self.blocks3(x)
x = x.permute(0, 3, 1, 2)
x = self.patch_embed4(x)
B, H, W, C = x.shape
x = x + self._get_pos_embed(self.pos_embed4, self.num_patches4, H, W)
x = self.blocks4(x)
x = x.flatten(1,2)
x = self.norm(x)
return x.mean(1)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
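The channels-last resize performed by `_get_pos_embed` can be sketched in isolation (the grid sizes and channel count below are illustrative, not taken from the model):

```python
import torch
import torch.nn.functional as F

# A channels-last (1, H, W, C) positional table is permuted to NCHW for
# F.interpolate, bilinearly resized, then permuted back -- the same dance
# as in _get_pos_embed above.
pos = torch.randn(1, 7, 7, 64)                       # table stored for a 7x7 grid
out = F.interpolate(pos.permute(0, 3, 1, 2),         # -> (1, 64, 7, 7)
                    size=(12, 12), mode="bilinear")  # resize to a 12x12 grid
out = out.permute(0, 2, 3, 1)                        # back to (1, 12, 12, 64)
print(tuple(out.shape))  # (1, 12, 12, 64)
```

This is why a model trained at 224 can run at other resolutions: only the positional tables need resizing, not the weights.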
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
""" ViT weight initialization
* When called without n, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
"""
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
elif name.startswith('pre_logits'):
lecun_normal_(module.weight)
nn.init.zeros_(module.bias)
else:
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
elif isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
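A minimal sketch of the name-aware dispatch in `_init_vit_weights` (the module names here are hypothetical). Note that plain `nn.Module.apply` passes only the module, so for the name branches to fire the caller must supply names, e.g. from `named_modules()`:

```python
import torch.nn as nn

# Name-aware init: a classifier named 'head' is zero-initialized, every
# other Linear gets a truncated-normal init, mirroring the branches above.
def init_by_name(module, name=''):
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            nn.init.trunc_normal_(module.weight, std=.02)
            nn.init.zeros_(module.bias)

model = nn.ModuleDict({'head': nn.Linear(8, 4), 'proj': nn.Linear(8, 8)})
for name, m in model.named_modules():   # named_modules supplies the name
    init_by_name(m, name)
print(model['head'].weight.abs().sum().item())  # 0.0
```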
@register_model
def iformer_small(pretrained=False, **kwargs):
"""
Params: 19.866M, FLOPs: 4.849G, ImageNet top-1: 83.382
"""
depths = [3, 3, 9, 3]
embed_dims = [96, 192, 320, 384]
num_heads = [3, 6, 10, 12]
attention_heads = [1]*3 + [3]*3 + [7] * 4 + [9] * 5 + [11] * 3
model = InceptionTransformer(img_size=224,
depths=depths,
embed_dims=embed_dims,
num_heads=num_heads,
attention_heads=attention_heads,
use_layer_scale=True, layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['iformer_small']
if pretrained:
url = model.default_cfg['url']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
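A quick sanity check, using the constants above, of how the per-block head schedule lines up with the stage depths (stage 3's nine blocks are split 4+5 between 7-head and 9-head groups):

```python
# The attention_heads schedule supplies one entry per Block across all four
# stages, so its length must equal sum(depths).
depths = [3, 3, 9, 3]
attention_heads = [1] * 3 + [3] * 3 + [7] * 4 + [9] * 5 + [11] * 3
print(len(attention_heads), sum(depths))  # 18 18
```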
@register_model
def iformer_small_384(pretrained=False, **kwargs):
depths = [3, 3, 9, 3]
embed_dims = [96, 192, 320, 384]
num_heads = [3, 6, 10, 12]
attention_heads = [1]*3 + [3]*3 + [7] * 4 + [9] * 5 + [11] * 3
model = InceptionTransformer(img_size=384,
depths=depths,
embed_dims=embed_dims,
num_heads=num_heads,
attention_heads=attention_heads,
use_layer_scale=True, layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['iformer_small_384']
if pretrained:
url = model.default_cfg['url']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def iformer_base(pretrained=False, **kwargs):
"""
Params: 47.866M, FLOPs: 9.379G, ImageNet top-1: 84.598
"""
depths = [4, 6, 14, 6]
embed_dims = [96, 192, 384, 512]
num_heads = [3, 6, 12, 16]
attention_heads = [1]*4 + [3]*6 + [8] * 7 + [10] * 7 + [15] * 6
model = InceptionTransformer(img_size=224,
depths=depths,
embed_dims=embed_dims,
num_heads=num_heads,
attention_heads=attention_heads,
use_layer_scale=True, layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['iformer_base']
if pretrained:
url = model.default_cfg['url']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def iformer_base_384(pretrained=False, **kwargs):
depths = [4, 6, 14, 6]
embed_dims = [96, 192, 384, 512]
num_heads = [3, 6, 12, 16]
attention_heads = [1]*4 + [3]*6 + [8] * 7 + [10] * 7 + [15] * 6
model = InceptionTransformer(img_size=384,
depths=depths,
embed_dims=embed_dims,
num_heads=num_heads,
attention_heads=attention_heads,
use_layer_scale=True, layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['iformer_base_384']
if pretrained:
url = model.default_cfg['url']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def iformer_large(pretrained=False, **kwargs):
"""
Params: 86.637M, FLOPs: 14.048G, ImageNet top-1: 84.752
"""
depths = [4, 6, 18, 8]
embed_dims = [96, 192, 448, 640]
num_heads = [3, 6, 14, 20]
attention_heads = [1]*4 + [3]*6 + [10] * 9 + [12] * 9 + [19] * 8
model = InceptionTransformer(img_size=224,
depths=depths,
embed_dims=embed_dims,
num_heads=num_heads,
attention_heads=attention_heads,
use_layer_scale=True, layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['iformer_large']
if pretrained:
url = model.default_cfg['url']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
@register_model
def iformer_large_384(pretrained=False, **kwargs):
depths = [4, 6, 18, 8]
embed_dims = [96, 192, 448, 640]
num_heads = [3, 6, 14, 20]
attention_heads = [1]*4 + [3]*6 + [10] * 9 + [12] * 9 + [19] * 8
model = InceptionTransformer(img_size=384,
depths=depths,
embed_dims=embed_dims,
num_heads=num_heads,
attention_heads=attention_heads,
use_layer_scale=True, layer_scale_init_value=1e-6,
**kwargs)
model.default_cfg = default_cfgs['iformer_large_384']
if pretrained:
url = model.default_cfg['url']
checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
model.load_state_dict(checkpoint)
return model
================================================
FILE: setup.cfg
================================================
[dist_conda]
conda_name_differences = 'torch:pytorch'
channels = pytorch
noarch = True
================================================
FILE: train.py
================================================
#!/usr/bin/env python3
""" ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import *
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
import models
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# ============self parameters==========================
# Dataset parameters
parser.add_argument('data_dir', metavar='DIR',
help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
# Model parameters
parser.add_argument('--model', default='deit_small_patch16_224', type=str, metavar='MODEL',
help='Name of model to train (default: "deit_small_patch16_224")')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='input batch size for training (default: 128)')
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='validation batch size override (default: None)')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=2e-5,
help='weight decay (default: 2e-5)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
help='learning rate (default: 0.05)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
parser.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
parser.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
parser.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)')
parser.add_argument('--aug-repeats', type=int, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
parser.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
parser.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic; default: "random")')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
parser.add_argument('--bn-tf', action='store_true', default=False,
help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
parser.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
parser.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
parser.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('--checkpoint-hist', type=int, default=1, metavar='N',
help='number of checkpoints to keep (default: 1)')
parser.add_argument('-j', '--workers', type=int, default=10, metavar='N',
help='how many training processes to use (default: 10)')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=True,
help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1")')
parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
parser.set_defaults(**cfg)
# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
args = parser.parse_args(remaining)
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
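The two-stage parsing above can be sketched with a stand-in for the YAML contents (flag names and values here are illustrative):

```python
import argparse

# A throwaway parser extracts only --config; the config's key-values become
# defaults of the main parser via set_defaults, so flags given explicitly on
# the command line still override the config file.
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str)

main_parser = argparse.ArgumentParser()
main_parser.add_argument('--lr', type=float, default=0.05)
main_parser.add_argument('--epochs', type=int, default=300)

argv = ['--config', 'run.yaml', '--lr', '0.01']
_, remaining = config_parser.parse_known_args(argv)
main_parser.set_defaults(**{'epochs': 100})   # stands in for the YAML contents
args = main_parser.parse_args(remaining)
print(args.lr, args.epochs)  # 0.01 100
```

Precedence ends up as: explicit CLI flag > config-file value > parser default.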
def main():
setup_default_logging()
args, args_text = _parse_args()
if args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
else:
_logger.info('Training with a single process on 1 GPU.')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.nati
SYMBOL INDEX (55 symbols across 4 files)
FILE: fine-tune.py
function _parse_args (line 310) | def _parse_args():
function load_init_checkpoint (line 327) | def load_init_checkpoint(model, checkpoint_path):
function main (line 350) | def main():
function train_one_epoch (line 713) | def train_one_epoch(
function concat_all_gather (line 824) | def concat_all_gather(tensor):
function validate (line 837) | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_su...
FILE: models/inception_transformer.py
function _cfg (line 42) | def _cfg(url='', **kwargs):
function _no_grad_trunc_normal_ (line 72) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
function trunc_normal_ (line 108) | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
function variance_scaling_ (line 130) | def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='no...
function lecun_normal_ (line 153) | def lecun_normal_(tensor):
class PatchEmbed (line 156) | class PatchEmbed(nn.Module):
method __init__ (line 159) | def __init__(self, img_size=224, kernel_size=16, stride=16, padding=0...
method forward (line 169) | def forward(self, x):
class FirstPatchEmbed (line 176) | class FirstPatchEmbed(nn.Module):
method __init__ (line 179) | def __init__(self, kernel_size=3, stride=2, padding=1, in_chans=3, em...
method forward (line 188) | def forward(self, x):
class HighMixer (line 198) | class HighMixer(nn.Module):
method __init__ (line 199) | def __init__(self, dim, kernel_size=3, stride=1, padding=1,
method forward (line 217) | def forward(self, x):
class LowMixer (line 233) | class LowMixer(nn.Module):
method __init__ (line 234) | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., poo...
method att_fun (line 249) | def att_fun(self, q, k, v, B, N, C):
method forward (line 257) | def forward(self, x):
class Mixer (line 271) | class Mixer(nn.Module):
method __init__ (line 272) | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., pro...
method forward (line 289) | def forward(self, x):
class Block (line 306) | class Block(nn.Module):
method __init__ (line 308) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=...
method forward (line 330) | def forward(self, x):
class InceptionTransformer (line 339) | class InceptionTransformer(nn.Module):
method __init__ (line 340) | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classe...
method init_weights (line 415) | def init_weights(self, mode=''):
method _init_weights (line 423) | def _init_weights(self, m):
method no_weight_decay (line 429) | def no_weight_decay(self):
method get_classifier (line 432) | def get_classifier(self):
method reset_classifier (line 438) | def reset_classifier(self, num_classes, global_pool=''):
method _get_pos_embed (line 444) | def _get_pos_embed(self, pos_embed, num_patches_def, H, W):
method forward_features (line 452) | def forward_features(self, x):
method forward (line 480) | def forward(self, x):
function _init_vit_weights (line 485) | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: floa...
function iformer_small (line 512) | def iformer_small(pretrained=False, **kwargs):
function iformer_small_384 (line 536) | def iformer_small_384(pretrained=False, **kwargs):
function iformer_base (line 560) | def iformer_base(pretrained=False, **kwargs):
function iformer_base_384 (line 584) | def iformer_base_384(pretrained=False, **kwargs):
function iformer_large (line 606) | def iformer_large(pretrained=False, **kwargs):
function iformer_large_384 (line 630) | def iformer_large_384(pretrained=False, **kwargs):
FILE: train.py
function _parse_args (line 314) | def _parse_args():
function main (line 331) | def main():
function train_one_epoch (line 687) | def train_one_epoch(
function concat_all_gather (line 797) | def concat_all_gather(tensor):
function validate (line 810) | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_su...
FILE: validate.py
function validate (line 123) | def validate(args):
function main (line 297) | def main():
function write_results (line 361) | def write_results(results_file, results):
Condensed preview — 21 files, each showing path, character count, and a content snippet.
[
{
"path": "LICENSE",
"chars": 11343,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "MANIFEST.in",
"chars": 34,
"preview": "include timm/models/pruned/*.txt\n\n"
},
{
"path": "README.md",
"chars": 5481,
"preview": "# iFormer: [Inception Transformer](http://arxiv.org/abs/2205.12956) (NeurIPS 2022 Oral)\nThis is a PyTorch implementation"
},
{
"path": "checkpoint/iformer_base/args.yaml",
"chars": 1871,
"preview": "aa: rand-m9-mstd0.5-inc1\namp: false\napex_amp: false\naug_repeats: 3\naug_splits: 0\nbatch_size: 64\nbce_loss: false\nbce_targ"
},
{
"path": "checkpoint/iformer_base/summary.csv",
"chars": 24038,
"preview": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,6.908455812014067,6.860800276947021,0.3619999967956543,1.5300000119018"
},
{
"path": "checkpoint/iformer_large/args.yaml",
"chars": 1872,
"preview": "aa: rand-m9-mstd0.5-inc1\namp: false\napex_amp: true\naug_repeats: 3\naug_splits: 0\nbatch_size: 64\nbce_loss: false\nbce_targe"
},
{
"path": "checkpoint/iformer_large/summary.csv",
"chars": 24106,
"preview": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,6.910546504534208,6.850864137573242,0.4499999999666214,1.8380000002288"
},
{
"path": "checkpoint/iformer_small/args.yaml",
"chars": 1870,
"preview": "aa: rand-m9-mstd0.5-inc1\namp: false\napex_amp: false\naug_repeats: 0\naug_splits: 0\nbatch_size: 64\nbce_loss: false\nbce_targ"
},
{
"path": "checkpoint/iformer_small/summary.csv",
"chars": 24023,
"preview": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,6.9094244333413934,6.864997950286865,0.3860000048828125,1.522000004882"
},
{
"path": "checkpoint_384/iformer_base_384/args.yaml",
"chars": 1926,
"preview": "aa: rand-m9-mstd0.5-inc1\namp: false\napex_amp: true\naug_repeats: 3\naug_splits: 0\nbatch_size: 32\nbce_loss: false\nbce_targe"
},
{
"path": "checkpoint_384/iformer_base_384/summary.csv",
"chars": 1683,
"preview": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,2.162344753742218,0.6423477774047851,85.416,97.554\r\n1,2.14225176152061"
},
{
"path": "checkpoint_384/iformer_large_384/args.yaml",
"chars": 1929,
"preview": "aa: rand-m9-mstd0.5-inc1\namp: false\napex_amp: true\naug_repeats: 3\naug_splits: 0\nbatch_size: 32\nbce_loss: false\nbce_targe"
},
{
"path": "checkpoint_384/iformer_large_384/summary.csv",
"chars": 2357,
"preview": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,2.1128095227938433,0.6339371264839172,85.65200001953124,97.53999996093"
},
{
"path": "checkpoint_384/iformer_small_384/args.yaml",
"chars": 1936,
"preview": "aa: rand-m9-mstd0.5-inc1\namp: false\napex_amp: false\naug_repeats: 0\naug_splits: 0\nbatch_size: 32\nbce_loss: false\nbce_targ"
},
{
"path": "checkpoint_384/iformer_small_384/summary.csv",
"chars": 2354,
"preview": "epoch,train_loss,eval_loss,eval_top1,eval_top5\r\n0,2.3478839855927687,0.6923815357589722,84.32999999023437,97.18999998535"
},
{
"path": "fine-tune.py",
"chars": 43237,
"preview": "#!/usr/bin/env python3\n\"\"\" ImageNet Training Script\n\nThis is intended to be a lean and easily modifiable ImageNet traini"
},
{
"path": "models/__init__.py",
"chars": 37,
"preview": "from .inception_transformer import *\n"
},
{
"path": "models/inception_transformer.py",
"chars": 25810,
"preview": "# Copyright 2022 Garena Online Private Limited\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you"
},
{
"path": "setup.cfg",
"chars": 88,
"preview": "[dist_conda]\n\nconda_name_differences = 'torch:pytorch'\nchannels = pytorch\nnoarch = True\n"
},
{
"path": "train.py",
"chars": 42065,
"preview": "#!/usr/bin/env python3\n\"\"\" ImageNet Training Script\n\nThis is intended to be a lean and easily modifiable ImageNet traini"
},
{
"path": "validate.py",
"chars": 16270,
"preview": "#!/usr/bin/env python3\n\"\"\" ImageNet Validation Script\n\nThis is intended to be a lean and easily modifiable ImageNet vali"
}
]