Repository: blackfeather-wang/GFNet-Pytorch
Branch: master
Commit: 8a6775c8267c
Files: 103
Total size: 610.4 KB

Directory structure:
gitextract_6lcfcbge/

├── .gitignore
├── LICENSE
├── README.md
├── configs.py
├── inference.py
├── models/
│   ├── __init__.py
│   ├── activations/
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── activations_autofn.py
│   │   ├── activations_jit.py
│   │   └── config.py
│   ├── config.py
│   ├── conv2d_layers.py
│   ├── densenet.py
│   ├── efficientnet_builder.py
│   ├── gen_efficientnet.py
│   ├── helpers.py
│   ├── mobilenetv3.py
│   ├── model_factory.py
│   ├── resnet.py
│   └── version.py
├── network.py
├── pycls/
│   ├── __init__.py
│   ├── cfgs/
│   │   ├── RegNetY-1.6GF_dds_8gpu.yaml
│   │   ├── RegNetY-600MF_dds_8gpu.yaml
│   │   └── RegNetY-800MF_dds_8gpu.yaml
│   ├── core/
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── losses.py
│   │   ├── model_builder.py
│   │   ├── old_config.py
│   │   └── optimizer.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── cifar10.py
│   │   ├── imagenet.py
│   │   ├── loader.py
│   │   ├── paths.py
│   │   └── transforms.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── anynet.py
│   │   ├── effnet.py
│   │   ├── regnet.py
│   │   └── resnet.py
│   └── utils/
│       ├── __init__.py
│       ├── benchmark.py
│       ├── checkpoint.py
│       ├── distributed.py
│       ├── error_handler.py
│       ├── io.py
│       ├── logging.py
│       ├── lr_policy.py
│       ├── meters.py
│       ├── metrics.py
│       ├── multiprocessing.py
│       ├── net.py
│       ├── plotting.py
│       └── timer.py
├── simplejson/
│   ├── __init__.py
│   ├── _speedups.c
│   ├── compat.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── errors.py
│   ├── ordered_dict.py
│   ├── raw_json.py
│   ├── scanner.py
│   ├── tests/
│   │   ├── __init__.py
│   │   ├── test_bigint_as_string.py
│   │   ├── test_bitsize_int_as_string.py
│   │   ├── test_check_circular.py
│   │   ├── test_decimal.py
│   │   ├── test_decode.py
│   │   ├── test_default.py
│   │   ├── test_dump.py
│   │   ├── test_encode_basestring_ascii.py
│   │   ├── test_encode_for_html.py
│   │   ├── test_errors.py
│   │   ├── test_fail.py
│   │   ├── test_float.py
│   │   ├── test_for_json.py
│   │   ├── test_indent.py
│   │   ├── test_item_sort_key.py
│   │   ├── test_iterable.py
│   │   ├── test_namedtuple.py
│   │   ├── test_pass1.py
│   │   ├── test_pass2.py
│   │   ├── test_pass3.py
│   │   ├── test_raw_json.py
│   │   ├── test_recursion.py
│   │   ├── test_scanstring.py
│   │   ├── test_separators.py
│   │   ├── test_speedups.py
│   │   ├── test_str_subclass.py
│   │   ├── test_subclass.py
│   │   ├── test_tool.py
│   │   ├── test_tuple.py
│   │   └── test_unicode.py
│   └── tool.py
├── train.py
├── utils.py
└── yacs/
    ├── __init__.py
    ├── config.py
    └── tests.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================

models/.DS_Store
figures/.DS_Store


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# Glance-and-Focus Networks (PyTorch)

This repo contains the official code and pre-trained models for Glance-and-Focus Networks (GFNet).

- (NeurIPS 2020) [Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification](https://arxiv.org/abs/2010.05300)
- (T-PAMI) [Glance and Focus Networks for Dynamic Visual Recognition](https://arxiv.org/abs/2201.03014)

**Update on 2020/12/28: Released the training code.**

**Update on 2020/10/08: Released the pre-trained models and the inference code on ImageNet.**

## Introduction

<p align="center">
    <img src="figures/examples.png" width= "420">
</p>

Inspired by the fact that not all regions in an image are task-relevant, we propose a novel framework that performs efficient image classification by processing a sequence of relatively small inputs, which are strategically cropped from the original image.
Experiments on ImageNet show that our method consistently improves the computational efficiency of a wide variety of deep models.
For example, it further reduces the average latency of the highly efficient MobileNet-V3 on an iPhone XS Max by 20% without sacrificing accuracy.
<p align="center">
    <img src="figures/overview.png" width= "810">
</p>



## Citation

```
@inproceedings{NeurIPS2020_7866,
        title={Glance and Focus: a Dynamic Approach to Reducing Spatial Redundancy in Image Classification},
        author={Wang, Yulin and Lv, Kangchen and Huang, Rui and Song, Shiji and Yang, Le and Huang, Gao},
        booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
        year={2020},
}

@article{huang2023glance,
        title={Glance and Focus Networks for Dynamic Visual Recognition}, 
        author={Huang, Gao and Wang, Yulin and Lv, Kangchen and Jiang, Haojun and Huang, Wenhui and Qi, Pengfei and Song, Shiji},
        journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 
        year={2023},
        volume={45},
        number={4},
        pages={4605-4621},
        doi={10.1109/TPAMI.2022.3196959}
}
```



## Results

- Top-1 accuracy on ImageNet v.s. Multiply-Adds
<p align="center">
    <img src="figures/result_main.png" width= "810">
</p>

- Top-1 accuracy on ImageNet v.s. Inference Latency (ms) on an iPhone XS Max
<p align="center">
    <img src="figures/result_speed.png" width= "540">
</p>


- Visualization
<p align="center">
    <img src="figures/result_visual.png" width= "810">
</p>


## Pre-trained Models


|Backbone CNNs|Patch Size|T|Links|
|-----|------|-----|-----|
|ResNet-50| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4c55dd9472b4416cbdc9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Iun8o4o7cQL-7vSwKyNfefOgwb9-o9kD/view?usp=sharing)|
|ResNet-50| 128x128| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1cbed71346e54a129771/?dl=1) / [Google Drive](https://drive.google.com/file/d/1cEj0dXO7BfzQNd5fcYZOQekoAe3_DPia/view?usp=sharing)|
|DenseNet-121| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c75c77d2f2054872ac20/?dl=1) / [Google Drive](https://drive.google.com/file/d/1UflIM29Npas0rTQSxPqwAT6zHbFkQq6R/view?usp=sharing)|
|DenseNet-169| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/83fef24e667b4dccace1/?dl=1) / [Google Drive](https://drive.google.com/file/d/1pBo22i6VsWJWtw2xJw1_bTSMO3HgJNDL/view?usp=sharing)|
|DenseNet-201| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/85a57b82e592470892e0/?dl=1) / [Google Drive](https://drive.google.com/file/d/1sETDr7dP5Q525fRMTIl2jlt9Eg2qh2dx/view?usp=sharing)|
|RegNet-Y-600MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/2638f038d3b1465da59e/?dl=1) / [Google Drive](https://drive.google.com/file/d/1FCR14wUiNrIXb81cU1pDPcy4bPSRXrqe/view?usp=sharing)|
|RegNet-Y-800MF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/686e411e72894b789dde/?dl=1) / [Google Drive](https://drive.google.com/file/d/1On39MwbJY5Zagz7gNtKBFwMWhhHskfZq/view?usp=sharing)|
|RegNet-Y-1.6GF| 96x96| 5|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/90116ad21ee74843b0ef/?dl=1) / [Google Drive](https://drive.google.com/file/d/1rMe0LU8m4BF3udII71JT0VLPBE2G-eCJ/view?usp=sharing)|
|MobileNet-V3-Large (1.00)| 96x96| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/4a4e8486b83b4dbeb06c/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Dw16jPlw2hR8EaWbd_1Ujd6Df9n1gHgj/view?usp=sharing)|
|MobileNet-V3-Large (1.00)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/ab0f6fc3997d4771a4c9/?dl=1) / [Google Drive](https://drive.google.com/file/d/1Ud_olyer-YgAb667YUKs38C2O1Yb6Nom/view?usp=sharing)|
|MobileNet-V3-Large (1.25)| 128x128| 3|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b2052c3af7734f688bc7/?dl=1) / [Google Drive](https://drive.google.com/file/d/14zj1Ci0i4nYceu-f2ZckFMjRmYDGtJpl/view?usp=sharing)|
|EfficientNet-B2| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/1a490deecd34470580da/?dl=1) / [Google Drive](https://drive.google.com/file/d/1LBBPrrYZzKKqCnoZH1kPfQQZ5ixkEmjz/view?usp=sharing)|
|EfficientNet-B3| 128x128| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/d5182a2257bb481ea622/?dl=1) / [Google Drive](https://drive.google.com/file/d/1fdxwimcuQAXBOsbOdGw8Ee43PgeHHZTA/view?usp=sharing)|
|EfficientNet-B3| 144x144| 4|[Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/f96abfb6de13430aa663/?dl=1) / [Google Drive](https://drive.google.com/file/d/1OVTGI6d2nsN5Hz5T_qLnYeUIBL5oMeeU/view?usp=sharing)|

- What the checkpoints contain:

```
**.pth.tar
├── model_name: name of the backbone CNNs (e.g., resnet50, densenet121)
├── patch_size: size of image patches (i.e., H' or W' in the paper)
├── model_prime_state_dict, model_state_dict, fc, policy: state dictionaries of the four components of GFNets
├── model_flops, policy_flops, fc_flops: Multiply-Adds of inferring the encoder, patch proposal network and classifier for once
├── flops: a list containing the Multiply-Adds corresponding to each length of the input sequence during inference
├── anytime_classification: results of anytime prediction (in Top-1 accuracy)
├── dynamic_threshold: the confidence thresholds used in budgeted batch classification
├── budgeted_batch_classification: results of budgeted batch classification (a two-item list, [0] and [1] correspond to the two coordinates of a curve)

```
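A loaded checkpoint behaves like an ordinary Python dict with the keys listed above. The sketch below fabricates one with that layout (all numbers are invented for illustration) to show how the fields are typically consumed; with a real file you would obtain `ckpt` via `torch.load(path, map_location='cpu')`.

```python
# Fabricated example of the documented checkpoint layout (all numbers invented);
# a real checkpoint is obtained with: ckpt = torch.load(path, map_location='cpu')
ckpt = {
    'model_name': 'resnet50',
    'patch_size': 96,
    'flops': [1.1e9, 2.2e9, 3.3e9, 4.4e9, 5.5e9],              # Multiply-Adds per step
    'anytime_classification': [71.2, 73.5, 74.8, 75.4, 75.9],  # Top-1 (%) per step
    'dynamic_threshold': [0.9, 0.8, 0.7, 0.6, 0.0],            # confidence thresholds
    'budgeted_batch_classification': [[1.1e9, 5.5e9], [71.2, 75.9]],
}

maximum_length = len(ckpt['flops'])   # T, the number of glance/focus steps
flops_curve, acc_curve = ckpt['budgeted_batch_classification']
print(ckpt['model_name'], ckpt['patch_size'], maximum_length)
```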

## Requirements
- python 3.7.7
- pytorch 1.3.1
- torchvision 0.4.2
- pyyaml 5.3.1 (for RegNets)

## Evaluate Pre-trained Models

Read the evaluation results saved in pre-trained models
```
CUDA_VISIBLE_DEVICES=0 python inference.py --checkpoint_path PATH_TO_CHECKPOINTS  --eval_mode 0
```

Read the confidence thresholds saved in pre-trained models and infer the model on the validation set
```
CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS  --eval_mode 1
```

Determine confidence thresholds on the training set and infer the model on the validation set
```
CUDA_VISIBLE_DEVICES=0 python inference.py --data_url PATH_TO_DATASET --checkpoint_path PATH_TO_CHECKPOINTS  --eval_mode 2
```

The dataset is expected to be prepared as follows:
```
ImageNet
├── train
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...
├── val
│   ├── folder 1 (class 1)
│   ├── folder 2 (class 2)
│   ├── ...

```
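This is the layout that `torchvision.datasets.ImageFolder` expects: one sub-folder per class, with labels assigned by sorting the class-folder names. A minimal stdlib sketch of that convention (the temporary tree and the WordNet-style class ids are invented for illustration):

```python
import os
import tempfile

# Build a toy ImageNet-style tree: <root>/<split>/<class folder>/<images>
root = tempfile.mkdtemp()
for split in ('train', 'val'):
    for cls in ('n01440764', 'n01443537'):   # invented WordNet-style class ids
        os.makedirs(os.path.join(root, split, cls))

# ImageFolder assigns integer labels by sorting the class-folder names.
train_dir = os.path.join(root, 'train')
classes = sorted(d for d in os.listdir(train_dir)
                 if os.path.isdir(os.path.join(train_dir, d)))
class_to_idx = {c: i for i, c in enumerate(classes)}
print(class_to_idx)   # {'n01440764': 0, 'n01443537': 1}
```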


## Training

- Here we take training ResNet-50 (96x96, T=5) as an example. All the initialization models and stage-1/2 checkpoints used can be found in [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/d/ac7c47b3f9b04e098862/) / [Google Drive](https://drive.google.com/drive/folders/1yO2GviOnukSUgcTkptNLBBttSJQZk9yn?usp=sharing). Currently, this link covers ResNet and MobileNet-V3; we will add the others as soon as possible. If you need further help, feel free to contact us.

- The results in the paper are based on 2 Tesla V100 GPUs. For most experiments, up to 4 Titan Xp GPUs should be sufficient.

Training stage 1 requires the initialization checkpoints of the global encoder (model_prime) and the local encoder (model):
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 1 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --model_prime_path PATH_TO_CHECKPOINTS  --model_path PATH_TO_CHECKPOINTS
```

Training stage 2 requires a stage-1 checkpoint:
```
CUDA_VISIBLE_DEVICES=0 python train.py --data_url PATH_TO_DATASET --train_stage 2 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS
```

Training stage 3 requires a stage-2 checkpoint:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --data_url PATH_TO_DATASET --train_stage 3 --model_arch resnet50 --patch_size 96 --T 5 --print_freq 10 --checkpoint_path PATH_TO_CHECKPOINTS
```

## Contact
If you have any questions, please feel free to contact the authors. Yulin Wang: wang-yl19@mails.tsinghua.edu.cn.

## Acknowledgment
Our code of MobileNet-V3 and EfficientNet is from [here](https://github.com/rwightman/pytorch-image-models). Our code of RegNet is from [here](https://github.com/facebookresearch/pycls).

## To Do
- Update the code for visualization.

- Update the code for mixed precision training.



================================================
FILE: configs.py
================================================
from PIL import Image

model_configurations = {
    'resnet50': {
        'feature_num': 2048,
        'feature_map_channels': 2048,
        'policy_conv': False,
        'policy_hidden_dim': 1024,
        'fc_rnn': True,
        'fc_hidden_dim': 1024,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bicubic'
    },
    'densenet121': {
        'feature_num': 1024,
        'feature_map_channels': 1024,
        'policy_conv': False,
        'policy_hidden_dim': 1024,
        'fc_rnn': True,
        'fc_hidden_dim': 1024,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bilinear'
    },
    'densenet169': {
        'feature_num': 1664,
        'feature_map_channels': 1664,
        'policy_conv': False,
        'policy_hidden_dim': 1024,
        'fc_rnn': True,
        'fc_hidden_dim': 1024,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bilinear'
    },
    'densenet201': {
        'feature_num': 1920,
        'feature_map_channels': 1920,
        'policy_conv': False,
        'policy_hidden_dim': 1024,
        'fc_rnn': True,
        'fc_hidden_dim': 1024,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bilinear'
    },
    'mobilenetv3_large_100': {
        'feature_num': 1280,
        'feature_map_channels': 960,
        'policy_conv': True,
        'policy_hidden_dim': 256,      
        'fc_rnn': False,
        'fc_hidden_dim': None,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bicubic'
    },
    'mobilenetv3_large_125': {
        'feature_num': 1280,
        'feature_map_channels': 1200,
        'policy_conv': True,
        'policy_hidden_dim': 256,      
        'fc_rnn': False,
        'fc_hidden_dim': None,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bicubic'
    },
    'efficientnet_b2': {
        'feature_num': 1408,
        'feature_map_channels': 1408,
        'policy_conv': True,
        'policy_hidden_dim': 256,      
        'fc_rnn': False,
        'fc_hidden_dim': None,
        'image_size': 260,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BICUBIC,
        'prime_interpolation': 'bicubic'
    },
    'efficientnet_b3': {
        'feature_num': 1536,
        'feature_map_channels': 1536,
        'policy_conv': True,
        'policy_hidden_dim': 256,      
        'fc_rnn': False,
        'fc_hidden_dim': None,
        'image_size': 300,
        'crop_pct': 0.904,
        'dataset_interpolation': Image.BICUBIC,
        'prime_interpolation': 'bicubic'
    },
    'regnety_600m': {
        'feature_num': 608,
        'feature_map_channels': 608,
        'policy_conv': True,
        'policy_hidden_dim': 256,      
        'fc_rnn': True,
        'fc_hidden_dim': 1024,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bilinear',
        'cfg_file': 'pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml'
    },
    'regnety_800m': {
        'feature_num': 768,
        'feature_map_channels': 768,
        'policy_conv': True,
        'policy_hidden_dim': 256,      
        'fc_rnn': True,
        'fc_hidden_dim': 1024,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bilinear',
        'cfg_file': 'pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml'
    },
    'regnety_1.6g': {
        'feature_num': 888,
        'feature_map_channels': 888,
        'policy_conv': True,
        'policy_hidden_dim': 256,      
        'fc_rnn': True,
        'fc_hidden_dim': 1024,
        'image_size': 224,
        'crop_pct': 0.875,
        'dataset_interpolation': Image.BILINEAR,
        'prime_interpolation': 'bilinear',
        'cfg_file': 'pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml'
    }
}


train_configurations = {
    'resnet': {
        'backbone_lr': 0.01,
        'fc_stage_1_lr': 0.1,
        'fc_stage_3_lr': 0.01,
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'Nesterov': True,
        'batch_size': 256,
        'dsn_ratio': 1,
        'epoch_num': 60,
        'train_model_prime': True
    },
    'densenet': {
        'backbone_lr': 0.01,
        'fc_stage_1_lr': 0.1,
        'fc_stage_3_lr': 0.01,
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'Nesterov': True,
        'batch_size': 256,
        'dsn_ratio': 1,
        'epoch_num': 60,
        'train_model_prime': True
    },
    'efficientnet': {
        'backbone_lr': 0.005,
        'fc_stage_1_lr': 0.1,
        'fc_stage_3_lr': 0.01,
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'Nesterov': True,
        'batch_size': 256,
        'dsn_ratio': 5,
        'epoch_num': 30,
        'train_model_prime': False
    },
    'mobilenetv3': {
        'backbone_lr': 0.005,
        'fc_stage_1_lr': 0.1,
        'fc_stage_3_lr': 0.01,
        'weight_decay': 1e-4,
        'momentum': 0.9,
        'Nesterov': True,
        'batch_size': 256,
        'dsn_ratio': 5,
        'epoch_num': 90,
        'train_model_prime': False
    },
    'regnet': {
        'backbone_lr': 0.02,
        'fc_stage_1_lr': 0.1,
        'fc_stage_3_lr': 0.01,
        'weight_decay': 5e-5,
        'momentum': 0.9,
        'Nesterov': True,
        'batch_size': 256,
        'dsn_ratio': 1,
        'epoch_num': 60,
        'train_model_prime': True
    }
}
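Note that the training hyper-parameters above are grouped by backbone family rather than by exact architecture name, so a caller has to map e.g. `resnet50` onto the `'resnet'` entry. One way that lookup might be sketched (the `get_train_config` helper is our own illustration, not part of the repo, and the dict below is a trimmed copy of `train_configurations`):

```python
# Trimmed copy of train_configurations, for illustration only (full dict in configs.py).
train_configurations = {
    'resnet':      {'backbone_lr': 0.01,  'epoch_num': 60, 'train_model_prime': True},
    'regnet':      {'backbone_lr': 0.02,  'epoch_num': 60, 'train_model_prime': True},
    'mobilenetv3': {'backbone_lr': 0.005, 'epoch_num': 90, 'train_model_prime': False},
}

def get_train_config(model_arch):
    """Map a concrete arch name (e.g. 'resnet50') to its family's training config."""
    for family, cfg in train_configurations.items():
        if model_arch.startswith(family):
            return cfg
    raise KeyError(f'no training configuration for {model_arch!r}')

print(get_train_config('resnet50')['backbone_lr'])   # 0.01
```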


================================================
FILE: inference.py
================================================

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

from utils import *
from network import *
from configs import *

import math
import argparse

import models.resnet as resnet
import models.densenet as densenet
from models import create_model


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description='Inference code for GFNet')

parser.add_argument('--data_url', default='./data', type=str,
                    help='path to the dataset (ImageNet)')

parser.add_argument('--checkpoint_path', default='', type=str,
                    help='path to the pre-train model (default: none)')

parser.add_argument('--eval_mode', default=2, type=int,
                    help='mode 0 : read the evaluation results saved in pre-trained models\
                          mode 1 : read the confidence thresholds saved in pre-trained models and infer the model on the validation set\
                          mode 2 : determine confidence thresholds on the training set and infer the model on the validation set')

args = parser.parse_args()


def main():
    # load pretrained model
    checkpoint = torch.load(args.checkpoint_path)

    try:
        model_arch = checkpoint['model_name']
        patch_size = checkpoint['patch_size']
        prime_size = checkpoint['patch_size']
        flops = checkpoint['flops']
        model_flops = checkpoint['model_flops']
        policy_flops = checkpoint['policy_flops']
        fc_flops = checkpoint['fc_flops']
        anytime_classification = checkpoint['anytime_classification']
        budgeted_batch_classification = checkpoint['budgeted_batch_classification']
        dynamic_threshold = checkpoint['dynamic_threshold']
        maximum_length = len(checkpoint['flops'])
    except KeyError:
        print('Error: \n'
              'Please provide essential information '
              'for customized models (as we have done '
              'in pre-trained models)!\n'
              'At least the following information should be given: \n'
              '--model_name: name of the backbone CNNs (e.g., resnet50, densenet121)\n'
              '--patch_size: size of image patches (i.e., H\' or W\' in the paper)\n'
              '--flops: a list containing the Multiply-Adds corresponding to each '
              'length of the input sequence during inference')
        return

    model_configuration = model_configurations[model_arch]

    if args.eval_mode > 0:
        # create model
        if 'resnet' in model_arch:
            model = resnet.resnet50(pretrained=False)
            model_prime = resnet.resnet50(pretrained=False)            

        elif 'densenet' in model_arch:
            model = eval('densenet.' + model_arch)(pretrained=False)
            model_prime = eval('densenet.' + model_arch)(pretrained=False)

        elif 'efficientnet' in model_arch:
            model = create_model(model_arch, pretrained=False, num_classes=1000,
                                    drop_rate=0.3, drop_connect_rate=0.2)
            model_prime = create_model(model_arch, pretrained=False, num_classes=1000,
                                    drop_rate=0.3, drop_connect_rate=0.2)
        
        elif 'mobilenetv3' in model_arch:
            model = create_model(model_arch, pretrained=False, num_classes=1000,
                                    drop_rate=0.2, drop_connect_rate=0.2)
            model_prime = create_model(model_arch, pretrained=False, num_classes=1000,
                                    drop_rate=0.2, drop_connect_rate=0.2)
        
        elif 'regnet' in model_arch:
            import pycls.core.model_builder as model_builder
            from pycls.core.config import cfg
            cfg.merge_from_file(model_configuration['cfg_file'])
            cfg.freeze()

            model = model_builder.build_model()
            model_prime = model_builder.build_model()

        traindir = args.data_url.rstrip('/') + '/train/'
        valdir = args.data_url.rstrip('/') + '/val/'

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_set = datasets.ImageFolder(traindir, transforms.Compose([
                transforms.RandomResizedCrop(model_configuration['image_size'], interpolation=model_configuration['dataset_interpolation']),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize ]))
        train_set_index = torch.randperm(len(train_set))
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=256, num_workers=32, pin_memory=False,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(train_set_index[-200000:]))

        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose([
                transforms.Resize(int(model_configuration['image_size']/model_configuration['crop_pct']), interpolation = model_configuration['dataset_interpolation']),
                transforms.CenterCrop(model_configuration['image_size']),
                transforms.ToTensor(),
                normalize])),
            batch_size=256, shuffle=False, num_workers=16, pin_memory=False)

        state_dim = model_configuration['feature_map_channels'] * math.ceil(patch_size/32) * math.ceil(patch_size/32)
        
        memory = Memory()
        policy = ActorCritic(model_configuration['feature_map_channels'], state_dim, model_configuration['policy_hidden_dim'], model_configuration['policy_conv'])
        fc = Full_layer(model_configuration['feature_num'], model_configuration['fc_hidden_dim'], model_configuration['fc_rnn'])

        model = nn.DataParallel(model.cuda())
        model_prime = nn.DataParallel(model_prime.cuda())
        policy = policy.cuda()
        fc = fc.cuda()

        model.load_state_dict(checkpoint['model_state_dict'])
        model_prime.load_state_dict(checkpoint['model_prime_state_dict'])
        fc.load_state_dict(checkpoint['fc'])
        policy.load_state_dict(checkpoint['policy'])

        budgeted_batch_flops_list = []
        budgeted_batch_acc_list = []

        print('generate logits on test samples...')
        test_logits, test_targets, anytime_classification = generate_logits(model_prime, model, fc, memory, policy, val_loader, maximum_length, prime_size, patch_size, model_arch)
        
        if args.eval_mode == 2:
            print('generate logits on training samples...')
            dynamic_threshold = torch.zeros([39, maximum_length])
            train_logits, train_targets, _ = generate_logits(model_prime, model, fc, memory, policy, train_loader, maximum_length, prime_size, patch_size, model_arch)

        for p in range(1, 40):

            print('inference: {}/39'.format(p))

            # exit probabilities follow a geometric distribution over the stages
            _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
            probs = torch.exp(torch.log(_p) * torch.arange(1, maximum_length + 1).float())
            probs /= probs.sum()

            if args.eval_mode == 2:
                dynamic_threshold[p-1] = dynamic_find_threshold(train_logits, train_targets, probs)
            
            acc_step, flops_step = dynamic_evaluate(test_logits, test_targets, flops, dynamic_threshold[p-1])
            
            budgeted_batch_acc_list.append(acc_step)
            budgeted_batch_flops_list.append(flops_step)
        
        budgeted_batch_classification = [budgeted_batch_flops_list, budgeted_batch_acc_list]

    print('model_arch :', model_arch)
    print('patch_size :', patch_size)
    print('flops :', flops)
    print('model_flops :', model_flops)
    print('policy_flops :', policy_flops)
    print('fc_flops :', fc_flops)
    print('anytime_classification :', anytime_classification)
    print('budgeted_batch_classification :', budgeted_batch_classification)


def generate_logits(model_prime, model, fc, memory, policy, dataloader, maximum_length, prime_size, patch_size, model_arch):

    logits_list = []
    targets_list = []

    top1 = [AverageMeter() for _ in range(maximum_length)]
    model.eval()
    model_prime.eval()
    fc.eval()

    for i, (x, target) in enumerate(dataloader):

        logits_temp = torch.zeros(maximum_length, x.size(0), 1000)

        target_var = target.cuda()
        input_var = x.cuda()

        input_prime = get_prime(input_var, prime_size, model_configurations[model_arch]['prime_interpolation'])

        with torch.no_grad():

            output, state = model_prime(input_prime)
            
            if 'resnet' in model_arch or 'densenet' in model_arch:
                output = fc(output, restart=True)
            elif 'regnet' in model_arch:
                _ = fc(output, restart=True)
                output = model_prime.module.fc(output)
            else:
                _ = fc(output, restart=True)
                output = model_prime.module.classifier(output)

            logits_temp[0] = F.softmax(output, 1)
            acc = accuracy(output, target_var, topk=(1,))
            top1[0].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0))
            
            for patch_step in range(1, maximum_length):

                with torch.no_grad():
                    if patch_step == 1:
                        action = policy.act(state, memory, restart_batch=True)
                    else:
                        action = policy.act(state, memory)

                patches = get_patch(input_var, action, patch_size)
                output, state = model(patches)
                output = fc(output, restart=False)

                logits_temp[patch_step] = F.softmax(output, 1)
                acc = accuracy(output, target_var, topk=(1,))
                top1[patch_step].update(acc.sum(0).mul_(100.0 / x.size(0)).data.item(), x.size(0))

        logits_list.append(logits_temp)
        targets_list.append(target_var)

        memory.clear_memory()

    anytime_classification = [top1[index].ave for index in range(maximum_length)]

    return torch.cat(logits_list, 1), torch.cat(targets_list, 0), anytime_classification


def dynamic_find_threshold(logits, targets, p):

    n_stage, n_sample, c = logits.size()
    max_preds, argmax_preds = logits.max(dim=2, keepdim=False)
    _, sorted_idx = max_preds.sort(dim=1, descending=True)

    filtered = torch.zeros(n_sample)
    T = torch.Tensor(n_stage).fill_(1e8)

    for k in range(n_stage - 1):
        acc, count = 0.0, 0
        out_n = math.floor(n_sample * p[k])
        for i in range(n_sample):
            ori_idx = sorted_idx[k][i]
            if filtered[ori_idx] == 0:
                count += 1
                if count == out_n:
                    T[k] = max_preds[k][ori_idx]
                    break
        filtered.add_(max_preds[k].ge(T[k]).type_as(filtered))

    T[n_stage - 1] = -1e8
    return T
   

def dynamic_evaluate(logits, targets, flops, T):

    n_stage, n_sample, c = logits.size()
    max_preds, argmax_preds = logits.max(dim=2, keepdim=False)
    _, sorted_idx = max_preds.sort(dim=1, descending=True)

    acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage)
    acc, expected_flops = 0, 0
    for i in range(n_sample):
        gold_label = targets[i]
        for k in range(n_stage):
            if max_preds[k][i].item() >= T[k]:  # exit at the first stage whose confidence reaches the threshold
                if int(gold_label.item()) == int(argmax_preds[k][i].item()):
                    acc += 1
                    acc_rec[k] += 1
                exp[k] += 1
                break
    acc_all = 0
    for k in range(n_stage):
        _t = 1.0 * exp[k] / n_sample
        expected_flops += _t * flops[k]
        acc_all += acc_rec[k]

    return acc * 100.0 / n_sample, expected_flops.item()
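The exit rule implemented by `dynamic_find_threshold` / `dynamic_evaluate` can be illustrated in isolation: a sample leaves at the first stage whose top confidence clears that stage's threshold, and accuracy and cost are averaged over the exit points. A toy pure-Python sketch with made-up numbers (not real model outputs):

```python
# Toy sketch of confidence-threshold early exit (hypothetical numbers).
# conf[k][i]: top softmax confidence of sample i at stage k
# pred[k][i]: predicted class of sample i at stage k
conf = [[0.9, 0.4, 0.3], [0.95, 0.8, 0.5], [0.99, 0.9, 0.7]]
pred = [[1, 0, 2], [1, 2, 2], [1, 2, 0]]
labels = [1, 2, 0]
T = [0.85, 0.75, -1e8]          # last threshold is -inf: everyone exits there
flops = [1.0, 2.0, 3.0]         # cumulative cost of evaluating up to stage k

correct, expected_flops = 0, 0.0
for i, gold in enumerate(labels):
    for k in range(len(T)):
        if conf[k][i] >= T[k]:              # first stage confident enough
            correct += int(pred[k][i] == gold)
            expected_flops += flops[k] / len(labels)
            break

acc = 100.0 * correct / len(labels)
```

Lower thresholds push samples out earlier (less cost, possibly less accuracy); the 39 values of `p` above sweep exactly this trade-off.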


if __name__ == '__main__':
    main()


================================================
FILE: models/__init__.py
================================================
from .gen_efficientnet import *
from .mobilenetv3 import *
from .model_factory import create_model
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
from .activations import *
from .resnet import *
from .densenet import *

================================================
FILE: models/activations/__init__.py
================================================
from .config import *
from .activations_autofn import *
from .activations_jit import *
from .activations import *


_ACT_FN_DEFAULT = dict(
    swish=swish,
    mish=mish,
    relu=F.relu,
    relu6=F.relu6,
    sigmoid=sigmoid,
    tanh=tanh,
    hard_sigmoid=hard_sigmoid,
    hard_swish=hard_swish,
)

_ACT_FN_AUTO = dict(
    swish=swish_auto,
    mish=mish_auto,
)

_ACT_FN_JIT = dict(
    swish=swish_jit,
    mish=mish_jit,
    #hard_swish=hard_swish_jit,
    #hard_sigmoid_jit=hard_sigmoid_jit,
)

_ACT_LAYER_DEFAULT = dict(
    swish=Swish,
    mish=Mish,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    sigmoid=Sigmoid,
    tanh=Tanh,
    hard_sigmoid=HardSigmoid,
    hard_swish=HardSwish,
)

_ACT_LAYER_AUTO = dict(
    swish=SwishAuto,
    mish=MishAuto,
)

_ACT_LAYER_JIT = dict(
    swish=SwishJit,
    mish=MishJit,
    #hard_swish=HardSwishJit,
    #hard_sigmoid=HardSigmoidJit
)

_OVERRIDE_FN = dict()
_OVERRIDE_LAYER = dict()


def add_override_act_fn(name, fn):
    global _OVERRIDE_FN
    _OVERRIDE_FN[name] = fn


def update_override_act_fn(overrides):
    assert isinstance(overrides, dict)
    global _OVERRIDE_FN
    _OVERRIDE_FN.update(overrides)


def clear_override_act_fn():
    global _OVERRIDE_FN
    _OVERRIDE_FN = dict()


def add_override_act_layer(name, fn):
    _OVERRIDE_LAYER[name] = fn


def update_override_act_layer(overrides):
    assert isinstance(overrides, dict)
    global _OVERRIDE_LAYER
    _OVERRIDE_LAYER.update(overrides)


def clear_override_act_layer():
    global _OVERRIDE_LAYER
    _OVERRIDE_LAYER = dict()


def get_act_fn(name='relu'):
    """ Activation Function Factory
    Fetching activation fns by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.
    """
    if name in _OVERRIDE_FN:
        return _OVERRIDE_FN[name]
    if not config.is_exportable() and not config.is_scriptable():
        # If not exporting or scripting the model, first look for a JIT optimized version
        # of our activation, then a custom autograd.Function variant before defaulting to
        # a Python or Torch builtin impl
        if name in _ACT_FN_JIT:
            return _ACT_FN_JIT[name]
        if name in _ACT_FN_AUTO:
            return _ACT_FN_AUTO[name]
    return _ACT_FN_DEFAULT[name]


def get_act_layer(name='relu'):
    """ Activation Layer Factory
    Fetching activation layers by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.
    """
    if name in _OVERRIDE_LAYER:
        return _OVERRIDE_LAYER[name]
    if not config.is_exportable() and not config.is_scriptable():
        if name in _ACT_LAYER_JIT:
            return _ACT_LAYER_JIT[name]
        if name in _ACT_LAYER_AUTO:
            return _ACT_LAYER_AUTO[name]
    return _ACT_LAYER_DEFAULT[name]
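The lookup-with-override pattern used by `get_act_fn` / `get_act_layer` can be reduced to a few lines; a standalone sketch (the names `_DEFAULT`, `_OVERRIDE` and `get_act` are illustrative, not part of this module):

```python
import math

# Minimal sketch of an activation factory with an override registry.
_DEFAULT = {
    'relu': lambda x: max(x, 0.0),
    'sigmoid': lambda x: 1.0 / (1.0 + math.exp(-x)),
}
_OVERRIDE = {}

def get_act(name='relu'):
    # Overrides win; otherwise fall back to the default table.
    return _OVERRIDE.get(name, _DEFAULT[name])

_OVERRIDE['relu'] = lambda x: min(max(x, 0.0), 6.0)  # e.g. swap in relu6
```

The real factory adds two more fallback tiers (JIT and autograd.Function variants) that are skipped when exporting or scripting, but the precedence logic is the same.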




================================================
FILE: models/activations/activations.py
================================================
from torch import nn as nn
from torch.nn import functional as F


def swish(x, inplace: bool = False):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    """
    return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())


class Swish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, self.inplace)


def mish(x, inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    """
    # NOTE: no in-place variant here; the out-of-place computation is always used
    return x.mul(F.softplus(x).tanh())


class Mish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Mish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return mish(x, self.inplace)


def sigmoid(x, inplace: bool = False):
    return x.sigmoid_() if inplace else x.sigmoid()


# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Sigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.sigmoid_() if self.inplace else x.sigmoid()


def tanh(x, inplace: bool = False):
    return x.tanh_() if inplace else x.tanh()


# PyTorch has this, but not with a consistent inplace argument interface
class Tanh(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Tanh, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x.tanh_() if self.inplace else x.tanh()


def hard_swish(x, inplace: bool = False):
    inner = F.relu6(x + 3.).div_(6.)
    return x.mul_(inner) if inplace else x.mul(inner)


class HardSwish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, self.inplace)


def hard_sigmoid(x, inplace: bool = False):
    if inplace:
        return x.add_(3.).clamp_(0., 6.).div_(6.)
    else:
        return F.relu6(x + 3.) / 6.


class HardSigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_sigmoid(x, self.inplace)
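For a quick sanity check of the definitions above, scalar reference versions in pure Python (the `_ref` names are illustrative, mirroring the tensor code):

```python
import math

# Scalar reference versions of the activations defined above.
def swish_ref(x):
    return x / (1.0 + math.exp(-x))           # x * sigmoid(x)

def hard_sigmoid_ref(x):
    return min(max(x + 3.0, 0.0), 6.0) / 6.0  # relu6(x + 3) / 6

def hard_swish_ref(x):
    return x * hard_sigmoid_ref(x)            # x * hard_sigmoid(x)
```

Note the saturation behavior: hard_sigmoid is exactly 0 for x <= -3 and exactly 1 for x >= 3, so hard_swish coincides with the identity above 3 and is exactly zero below -3.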




================================================
FILE: models/activations/activations_autofn.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F


__all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto']


class SwishAutoFn(torch.autograd.Function):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    Memory efficient variant from:
     https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
    """
    @staticmethod
    def forward(ctx, x):
        result = x.mul(torch.sigmoid(x))
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        x_sigmoid = torch.sigmoid(x)
        return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid)))


def swish_auto(x, inplace=False):
    # inplace ignored
    return SwishAutoFn.apply(x)


class SwishAuto(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishAuto, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return SwishAutoFn.apply(x)


class MishAutoFn(torch.autograd.Function):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    Experimental memory-efficient variant
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        x_sigmoid = torch.sigmoid(x)
        x_tanh_sp = F.softplus(x).tanh()
        return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


def mish_auto(x, inplace=False):
    # inplace ignored
    return MishAutoFn.apply(x)


class MishAuto(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishAuto, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return MishAutoFn.apply(x)
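The closed-form gradient used in `SwishAutoFn.backward` can be verified against a central finite difference; a small pure-Python check (helper names are illustrative):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def swish(x):
    return x * sigmoid(x)

def swish_grad(x):
    # Analytic gradient from SwishAutoFn.backward:
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

def finite_diff(f, x, h=1e-6):
    # Central difference, O(h^2) accurate
    return (f(x + h) - f(x - h)) / (2.0 * h)
```

The same technique applies to the Mish backward formula a few lines up; saving only `x` (rather than intermediate activations) is what makes these variants memory efficient.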



================================================
FILE: models/activations/activations_jit.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F


__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit']
           #'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit']


@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish
    Inspired by a conversation between Jeremy Howard & Adam Paszke
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


def swish_jit(x, inplace=False):
    # inplace ignored
    return SwishJitAutoFn.apply(x)


class SwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishJit, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return SwishJitAutoFn.apply(x)


@torch.jit.script
def mish_jit_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


@torch.jit.script
def mish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)


def mish_jit(x, inplace=False):
    # inplace ignored
    return MishJitAutoFn.apply(x)


class MishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishJit, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return MishJitAutoFn.apply(x)


# @torch.jit.script
# def hard_swish_jit(x, inplace: bool = False):
#     return x.mul(F.relu6(x + 3.).mul_(1./6.))
#
#
# class HardSwishJit(nn.Module):
#     def __init__(self, inplace: bool = False):
#         super(HardSwishJit, self).__init__()
#
#     def forward(self, x):
#         return hard_swish_jit(x)
#
#
# @torch.jit.script
# def hard_sigmoid_jit(x, inplace: bool = False):
#     return F.relu6(x + 3.).mul(1./6.)
#
#
# class HardSigmoidJit(nn.Module):
#     def __init__(self, inplace: bool = False):
#         super(HardSigmoidJit, self).__init__()
#
#     def forward(self, x):
#         return hard_sigmoid_jit(x)


================================================
FILE: models/activations/config.py
================================================
""" Global Config and Constants
"""

__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False


def is_exportable():
    return _EXPORTABLE


def set_exportable(value):
    global _EXPORTABLE
    _EXPORTABLE = value


def is_scriptable():
    return _SCRIPTABLE


def set_scriptable(value):
    global _SCRIPTABLE
    _SCRIPTABLE = value



================================================
FILE: models/config.py
================================================
""" Global Config and Constants
"""

__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False


def is_exportable():
    return _EXPORTABLE


def set_exportable(value):
    global _EXPORTABLE
    _EXPORTABLE = value


def is_scriptable():
    return _SCRIPTABLE


def set_scriptable(value):
    global _SCRIPTABLE
    _SCRIPTABLE = value



================================================
FILE: models/conv2d_layers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from torch._six import container_abcs
except ImportError:  # torch >= 1.9 removed torch._six
    import collections.abc as container_abcs

from itertools import repeat
from functools import partial
from typing import Union, List, Tuple, Optional, Callable
import numpy as np
import math

from .activations.config import *


def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)


def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def _get_padding(kernel_size, stride=1, dilation=1, **_):
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


def _calc_same_pad(i: int, k: int, s: int, d: int):
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)


def _same_pad_arg(input_size, kernel_size, stride, dilation):
    ih, iw = input_size
    kh, kw = kernel_size
    pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
    pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
    return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]


def _split_channels(num_chan, num_groups):
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split
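The padding arithmetic above is easy to verify on concrete shapes; a small sketch mirroring `_calc_same_pad` and `_split_channels` (function names here are local to the example):

```python
import math

# Mirrors _calc_same_pad: total padding needed so that an input of length i,
# kernel k, stride s, dilation d produces ceil(i / s) outputs (TF 'SAME').
def same_pad(i, k, s, d):
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

# Mirrors _split_channels: near-even split, remainder goes to the first group.
def split_channels(num_chan, num_groups):
    split = [num_chan // num_groups] * num_groups
    split[0] += num_chan - sum(split)
    return split
```

Because the padding depends on the input size `i`, it cannot always be baked into the conv layer statically; that is why `get_padding_value` below falls back to dynamic padding (`Conv2dSame`) when `_is_static_pad` fails.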


def conv2d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
    ih, iw = x.size()[-2:]
    kh, kw = weight.size()[-2:]
    pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
    pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Conv2dSameExport(nn.Conv2d):
    """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    NOTE: This does not currently work with torch.jit.script
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSameExport, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.pad = None
        self.pad_input_size = (0, 0)

    def forward(self, x):
        input_size = x.size()[-2:]
        if self.pad is None:
            pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            self.pad_input_size = input_size
        else:
            assert self.pad_input_size == input_size

        x = self.pad(x)
        return F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def get_padding_value(padding, kernel_size, **kwargs):
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if _is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = _get_padding(kernel_size, **kwargs)
            else:
                # dynamic padding
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = _get_padding(kernel_size, **kwargs)
    return padding, dynamic


def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
    if is_dynamic:
        if is_exportable():
            assert not is_scriptable()
            return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
        else:
            return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
    else:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)


class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution
    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super(MixedConv2d, self).__init__()

        kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        num_groups = len(kernel_size)
        in_splits = _split_channels(in_channels, num_groups)
        out_splits = _split_channels(out_channels, num_groups)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
            conv_groups = out_ch if depthwise else 1
            self.add_module(
                str(idx),
                create_conv2d_pad(
                    in_ch, out_ch, k, stride=stride,
                    padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
            )
        self.splits = in_splits

    def forward(self, x):
        x_split = torch.split(x, self.splits, 1)
        x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())]
        x = torch.cat(x_out, 1)
        return x


def get_condconv_initializer(initializer, num_experts, expert_shape):
    def condconv_initializer(weight):
        """CondConv initializer function."""
        num_params = np.prod(expert_shape)
        if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
                weight.shape[1] != num_params):
            raise ValueError(
                'CondConv variables must have shape [num_experts, num_params]')
        for i in range(num_experts):
            initializer(weight[i].view(expert_shape))
    return condconv_initializer


class CondConv2d(nn.Module):
    """ Conditional Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983
    """
    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super(CondConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        self.dynamic_padding = is_padding_dynamic  # plain bool attribute, checked in forward(), for torchscript compatibility
        self.padding = _pair(padding_val)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.num_experts = num_experts

        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight_num_param = 1
        for wd in self.weight_shape:
            weight_num_param *= wd
        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        init_weight = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is not None:
            fan_in = np.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
            init_bias = get_condconv_initializer(
                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
            init_bias(self.bias)

    def forward(self, x, routing_weights):
        B, C, H, W = x.shape
        weight = torch.matmul(routing_weights, self.weight)
        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight = weight.view(new_weight_shape)
        bias = None
        if self.bias is not None:
            bias = torch.matmul(routing_weights, self.bias)
            bias = bias.view(B * self.out_channels)
        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
        x = x.view(1, B * C, H, W)
        if self.dynamic_padding:
            out = conv2d_same(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])

        # Literal port (from TF definition)
        # x = torch.split(x, 1, 0)
        # weight = torch.split(weight, 1, 0)
        # if self.bias is not None:
        #     bias = torch.matmul(routing_weights, self.bias)
        #     bias = torch.split(bias, 1, 0)
        # else:
        #     bias = [None] * B
        # out = []
        # for xi, wi, bi in zip(x, weight, bias):
        #     wi = wi.view(*self.weight_shape)
        #     if bi is not None:
        #         bi = bi.view(*self.bias_shape)
        #     out.append(self.conv_fn(
        #         xi, wi, bi, stride=self.stride, padding=self.padding,
        #         dilation=self.dilation, groups=self.groups))
        # out = torch.cat(out, 0)
        return out


def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    assert 'groups' not in kwargs  # only use 'depthwise' bool arg
    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
        # Only lists are used to define MixedConv2d kernel groups;
        # ints/tuples still go to a normal conv and specify the (h, w) kernel size.
        m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
    else:
        depthwise = kwargs.pop('depthwise', False)
        groups = out_chs if depthwise else 1
        if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
            m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
        else:
            m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
    return m
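The batch-folding trick in `CondConv2d.forward` above (reshape the input to `(1, B*C, H, W)` and convolve with `groups` multiplied by `B`) is equivalent to looping over samples, each with its own expert-mixed kernel. A minimal standalone check with plain `F.conv2d` (illustrative shapes, dynamic padding omitted):

```python
import torch
import torch.nn.functional as F

B, Cin, Cout, H, W, K = 2, 3, 4, 8, 8, 3
x = torch.randn(B, Cin, H, W)
w = torch.randn(B, Cout, Cin, K, K)  # a distinct (expert-mixed) kernel per sample

# batched path: fold batch into channels, one conv group per sample
out_b = F.conv2d(x.reshape(1, B * Cin, H, W),
                 w.reshape(B * Cout, Cin, K, K),
                 padding=K // 2, groups=B).reshape(B, Cout, H, W)

# reference path: per-sample loop (mirrors the commented "literal port" above)
out_ref = torch.cat([F.conv2d(x[i:i + 1], w[i], padding=K // 2)
                     for i in range(B)], dim=0)

assert torch.allclose(out_b, out_ref, atol=1e-5)
```

The grouped convolution dispatches each sample's channel slice to its own kernel slice, which is why the real forward passes `groups=self.groups * B`.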


================================================
FILE: models/densenet.py
================================================

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']


model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}



def densenet121(pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     **kwargs)
    if pretrained:
        # model.load_state_dict(torch.load(model_urls['densenet121']))
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = torch.load(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model



def densenet169(pretrained=False, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = torch.load(model_urls['densenet169'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model



def densenet201(pretrained=False, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = torch.load(model_urls['densenet201'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model



def densenet161(pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                     **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = torch.load(model_urls['densenet161'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model



class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        self.feature_num = num_features
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)  # in-place variant; bare kaiming_normal is deprecated
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        features = self.features(x)
        x = F.relu(features, inplace=True)
        out = self.avgpool(x).view(features.size(0), -1)
        return out, x.detach()
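The `num_features` bookkeeping in `DenseNet.__init__` (each dense layer concatenates `growth_rate` channels; each transition halves the count) can be replayed by hand. A small sketch reproducing the classifier input widths of the stock configs:

```python
def densenet_feature_count(num_init=64, growth=32, blocks=(6, 12, 24, 16)):
    """Replays the channel bookkeeping from DenseNet.__init__."""
    n = num_init
    for i, layers in enumerate(blocks):
        n += layers * growth       # every dense layer concatenates `growth` channels
        if i != len(blocks) - 1:
            n //= 2                # transition layer halves the channel count
    return n

print(densenet_feature_count())                         # densenet121 -> 1024
print(densenet_feature_count(96, 48, (6, 12, 36, 24)))  # densenet161 -> 2208
```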


================================================
FILE: models/efficientnet_builder.py
================================================
import re
from copy import deepcopy

from .conv2d_layers import *
from .activations import *


# Defaults used for Google/Tensorflow training of mobile networks w/ RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (w/ .999 in search space) for paper
#
# PyTorch defaults are momentum = .1, eps = 1e-5
#
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)


def get_bn_args_tf():
    return _BN_ARGS_TF.copy()


def resolve_bn_args(kwargs):
    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
    bn_momentum = kwargs.pop('bn_momentum', None)
    if bn_momentum is not None:
        bn_args['momentum'] = bn_momentum
    bn_eps = kwargs.pop('bn_eps', None)
    if bn_eps is not None:
        bn_args['eps'] = bn_eps
    return bn_args
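`resolve_bn_args` pops its keys out of `kwargs`, so the remainder can be forwarded untouched to the model constructor, and an explicit `bn_momentum`/`bn_eps` wins over the TF defaults. A quick precedence check (both functions copied from above, inlining the TF defaults):

```python
def get_bn_args_tf():
    # TF defaults: decay .99 -> PT momentum 1 - .99, eps 1e-3
    return dict(momentum=1 - 0.99, eps=1e-3)

def resolve_bn_args(kwargs):
    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
    bn_momentum = kwargs.pop('bn_momentum', None)
    if bn_momentum is not None:
        bn_args['momentum'] = bn_momentum
    bn_eps = kwargs.pop('bn_eps', None)
    if bn_eps is not None:
        bn_args['eps'] = bn_eps
    return bn_args

# explicit bn_eps overrides the TF default; bn_tf/bn_eps are consumed from kwargs
kwargs = {'bn_tf': True, 'bn_eps': 1e-5, 'other': 1}
args = resolve_bn_args(kwargs)
```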


_SE_ARGS_DEFAULT = dict(
    gate_fn=sigmoid,
    act_layer=None,  # None == use containing block's activation layer
    reduce_mid=False,
    divisor=1)


def resolve_se_args(kwargs, in_chs, act_layer=None):
    se_kwargs = kwargs.copy() if kwargs is not None else {}
    # fill in args that aren't specified with the defaults
    for k, v in _SE_ARGS_DEFAULT.items():
        se_kwargs.setdefault(k, v)
    # some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_chs instead of in_chs
    if not se_kwargs.pop('reduce_mid'):
        se_kwargs['reduced_base_chs'] = in_chs
    # act_layer override, if it remains None, the containing block's act_layer will be used
    if se_kwargs['act_layer'] is None:
        assert act_layer is not None
        se_kwargs['act_layer'] = act_layer
    return se_kwargs


def resolve_act_layer(kwargs, default='relu'):
    act_layer = kwargs.pop('act_layer', default)
    if isinstance(act_layer, str):
        act_layer = get_act_layer(act_layer)
    return act_layer


def make_divisible(v: int, divisor: int = 8, min_value: int = None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # ensure round down does not go down by more than 10%.
        new_v += divisor
    return new_v


def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier."""
    if not multiplier:
        return channels
    channels *= multiplier
    return make_divisible(channels, divisor, channel_min)
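`make_divisible` snaps a channel count to the nearest multiple of `divisor`, then bumps it back up if rounding down would shave off more than 10%. A standalone sketch (both helpers copied from above):

```python
def make_divisible(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never round down by more than 10%
        new_v += divisor
    return new_v

def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    if not multiplier:
        return channels
    return make_divisible(channels * multiplier, divisor, channel_min)

print(round_channels(32, 1.2))  # 38.4 snaps to the nearest multiple of 8 -> 40
print(make_divisible(19, 16))   # rounding down to 16 would lose >10% -> 32
```

The 10% guard mostly matters for small channel counts paired with large divisors, as in the second call.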


def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
    """Apply drop connect."""
    if not training:
        return inputs

    keep_prob = 1 - drop_connect_rate
    random_tensor = keep_prob + torch.rand(
        (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
    random_tensor.floor_()  # binarize
    output = inputs.div(keep_prob) * random_tensor
    return output
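`drop_connect` implements per-sample stochastic depth: in eval mode the input passes through untouched, and in training each sample is either zeroed or scaled by `1 / keep_prob` so the expected value is preserved. A small sanity sketch (function copied from above):

```python
import torch

def drop_connect(inputs, training=False, drop_connect_rate=0.):
    if not training:
        return inputs
    keep_prob = 1 - drop_connect_rate
    random_tensor = keep_prob + torch.rand(
        (inputs.size(0), 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
    random_tensor.floor_()  # one keep/drop decision per sample
    return inputs.div(keep_prob) * random_tensor

x = torch.ones(8, 4, 2, 2)
out = drop_connect(x, training=True, drop_connect_rate=0.5)
per_sample = out.view(8, -1).sum(1)
# each sample is either fully dropped (sum 0) or scaled by 2 (sum 16 * 2 = 32)
assert set(per_sample.tolist()) <= {0.0, 32.0}
assert drop_connect(x, training=False, drop_connect_rate=0.5) is x
```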


class SqueezeExcite(nn.Module):

    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        # tensor.view + mean bad for ONNX export (produces mess of gather ops that break TensorRT)
        x_se = self.avg_pool(x)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        x = x * self.gate_fn(x_se)
        return x
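Squeeze-and-excite as used above: global-average-pool to 1x1, bottleneck the channels by `se_ratio`, expand back, and gate the input channel-wise. A minimal standalone mirror, with plain `torch.sigmoid` standing in for the imported `sigmoid` gate (names are illustrative):

```python
import torch
import torch.nn as nn

class TinySE(nn.Module):
    """Minimal squeeze-and-excite mirroring SqueezeExcite above."""
    def __init__(self, chs, se_ratio=0.25):
        super().__init__()
        red = max(1, int(chs * se_ratio))
        self.pool = nn.AdaptiveAvgPool2d(1)       # squeeze spatial dims to 1x1
        self.reduce = nn.Conv2d(chs, red, 1, bias=True)
        self.expand = nn.Conv2d(red, chs, 1, bias=True)

    def forward(self, x):
        s = torch.relu(self.reduce(self.pool(x)))
        return x * torch.sigmoid(self.expand(s))  # channel-wise gate

x = torch.randn(2, 16, 8, 8)
y = TinySE(16)(x)
assert y.shape == x.shape  # SE only re-weights channels; shape is unchanged
```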


class ConvBnAct(nn.Module):
    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(ConvBnAct, self).__init__()
        assert stride in [1, 2]
        norm_kwargs = norm_kwargs or {}
        self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x


class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in place of IR blocks with an expansion
    factor of 1.0. This is an alternative to an IR block with an optional first pw conv.
    """
    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        assert stride in [1, 2]
        norm_kwargs = norm_kwargs or {}
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.drop_connect_rate = drop_connect_rate

        self.conv_dw = select_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if se_ratio is not None and se_ratio > 0.:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = nn.Identity()

        self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()

    def forward(self, x):
        residual = x

        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.se(x)

        x = self.conv_pw(x)
        x = self.bn2(x)
        x = self.act2(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual
        return x


class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_connect_rate=0.):
        super(InvertedResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        conv_kwargs = conv_kwargs or {}
        mid_chs: int = make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_connect_rate = drop_connect_rate

        # Point-wise expansion
        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = select_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if se_ratio is not None and se_ratio > 0.:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = nn.Identity()  # for jit.script compat

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def forward(self, x):
        residual = x

        # Point-wise expansion
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.act2(x)

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual
        return x


class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 num_experts=0, drop_connect_rate=0.):

        self.num_experts = num_experts
        conv_kwargs = dict(num_experts=self.num_experts)

        super(CondConvResidual, self).__init__(
            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type,
            act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
            norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
            drop_connect_rate=drop_connect_rate)

        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
        residual = x

        # CondConv routing
        pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
        routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))

        # Point-wise expansion
        x = self.conv_pw(x, routing_weights)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw(x, routing_weights)
        x = self.bn2(x)
        x = self.act2(x)

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x, routing_weights)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual
        return x


class EdgeResidual(nn.Module):
    """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride"""

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
        super(EdgeResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_connect_rate = drop_connect_rate

        # Expansion convolution
        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if se_ratio is not None and se_ratio > 0.:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
        self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs)

    def forward(self, x):
        residual = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Squeeze-and-excitation
        x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn2(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual

        return x


class EfficientNetBuilder:
    """ Build Trunk Blocks for Efficient/Mobile Networks

    This ended up being somewhat of a cross between
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
    and
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    """

    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_layer=None, se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_connect_rate = drop_connect_rate

        # updated during build
        self.in_chs = None
        self.block_idx = 0
        self.block_count = 0

    def _round_channels(self, chs):
        return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

    def _make_block(self, ba):
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if 'fake_in_chs' in ba and ba['fake_in_chs']:
            # FIXME this is a hack to work around a mismatch in the original impl's input filters for EdgeTPU
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        if bt == 'ir':
            ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
            ba['se_kwargs'] = self.se_kwargs
            if ba.get('num_experts', 0) > 0:
                block = CondConvResidual(**ba)
            else:
                block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
            ba['se_kwargs'] = self.se_kwargs
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'er':
            ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
            ba['se_kwargs'] = self.se_kwargs
            block = EdgeResidual(**ba)
        elif bt == 'cn':
            block = ConvBnAct(**ba)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt
        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block

    def _make_stack(self, stack_args):
        blocks = []
        # each stack (stage) contains a list of block arguments
        for i, ba in enumerate(stack_args):
            if i >= 1:
                # only the first block in any stack can have a stride > 1
                ba['stride'] = 1
            block = self._make_block(ba)
            blocks.append(block)
            self.block_idx += 1  # incr global idx (across all stacks)
        return nn.Sequential(*blocks)

    def __call__(self, in_chs, block_args):
        """ Build the blocks
        Args:
            in_chs: Number of input-channels passed to first block
            block_args: A list of lists, outer list defines stages, inner
                list contains strings defining block configuration(s)
        Return:
             List of block stacks (each stack wrapped in nn.Sequential)
        """
        self.in_chs = in_chs
        self.block_count = sum([len(x) for x in block_args])
        self.block_idx = 0
        blocks = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stack_idx, stack in enumerate(block_args):
            assert isinstance(stack, list)
            stack = self._make_stack(stack)
            blocks.append(stack)
        return blocks


def _parse_ksize(ss):
    if ss.isdigit():
        return int(ss)
    else:
        return [int(k) for k in ss.split('.')]


def _decode_block_str(block_str):
    """ Decode block definition string

    Gets a list of block arg (dicts) through a string notation of arguments.
    E.g. ir_r2_k3_s2_e1_c16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type (
      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size,
    s - strides (1-9),
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio,
    a - expansion (pointwise) kernel size,
    p - projection (pointwise-linear) kernel size,
    n - activation fn ('re', 'r6', 'hs', or 'sw')
    Args:
        block_str: a string representation of block arguments.
    Returns:
        A list of block args (dicts)
    Raises:
        ValueError: if the string definition is not properly specified (TODO)
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]
    options = {}
    noskip = False
    for op in ops:
        # string options are checked individually; combine into a lookup if the set grows
        if op == 'noskip':
            noskip = True
        elif op.startswith('n'):
            # activation fn
            key = op[0]
            v = op[1:]
            if v == 're':
                value = get_act_layer('relu')
            elif v == 'r6':
                value = get_act_layer('relu6')
            elif v == 'hs':
                value = get_act_layer('hard_swish')
            elif v == 'sw':
                value = get_act_layer('swish')
            else:
                continue
            options[key] = value
        else:
            # all numeric options
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

    # if act_layer is None, the model default (passed to model init) will be used
    act_layer = options['n'] if 'n' in options else None
    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def

    num_repeat = int(options['r'])
    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type == 'ds' or block_type == 'dsa':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'er':
        block_args = dict(
            block_type=block_type,
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            fake_in_chs=fake_in_chs,
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_layer=act_layer,
        )
    else:
        raise ValueError('Unknown block type (%s)' % block_type)

    return block_args, num_repeat
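The key/value splitting inside the decoder can be sketched in isolation. `parse_options` below is an illustrative re-implementation (not part of this module's API) showing how a def string such as `'ir_r2_k3_s2_e6_c24_se0.25'` falls apart into its option pairs:

```python
import re

def parse_options(block_str):
    # illustrative sketch of the splitting logic above, not the module's API
    ops = block_str.split('_')
    block_type, ops = ops[0], ops[1:]   # the block type rides at the front
    options = {}
    for op in ops:
        if op == 'noskip':
            options['noskip'] = True
        else:
            # split 'se0.25' into key 'se' and value '0.25' at the first digit
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value
    return block_type, options

block_type, opts = parse_options('ir_r2_k3_s2_e6_c24_se0.25')
# block_type == 'ir'; opts == {'r': '2', 'k': '3', 's': '2', 'e': '6',
#                              'c': '24', 'se': '0.25'}
```

Note that values stay strings here; the real decoder converts them (`int`, `float`, `_parse_ksize`) per block type.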


def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """ Per-stage depth scaling
    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.
    """

    # We scale the total repeat count for each stage; there may be multiple
    # block arg defs per stage, so we sum them first.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage definitions
        # include single repeat stages that we'd prefer to keep that way as long as possible
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))

    # Proportionally distribute repeat count scaling to each block definition in the stage.
    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
    # The first block makes less sense to repeat in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]

    # Apply the calculated scaling to each block arg in the stage
    sa_scaled = []
    for ba, rep in zip(stack_args, repeats_scaled):
        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
    return sa_scaled
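The reverse-order distribution above can be exercised standalone. `scale_repeats` below is a minimal sketch (a hypothetical helper mirroring the loop, not the module's API) of how repeat counts get redistributed across block defs in a stage:

```python
import math

def scale_repeats(repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    # standalone sketch of the per-stage repeat distribution above
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
    repeats_scaled = []
    for r in repeats[::-1]:                      # reverse: first block scales last
        rs = max(1, round(r / num_repeat * num_repeat_scaled))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    return repeats_scaled[::-1]

# doubling depth on a (1, 2)-repeat stage grows the later block def first:
# scale_repeats([1, 2], depth_multiplier=2.0) -> [2, 4]
```

The reverse iteration is what keeps a single-repeat first block at one repeat for as long as possible when the multiplier grows.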


def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
    arch_args = []
    for stack_idx, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args = []
        repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = _decode_block_str(block_str)
            if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
                ba['num_experts'] *= experts_multiplier
            stack_args.append(ba)
            repeats.append(rep)
        if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
            arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
        else:
            arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
    return arch_args


def initialize_weight_goog(m, n='', fix_group_fanout=True):
    # weight init as per Tensorflow Official impl
    # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(
            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)  # fan-out
        fan_in = 0
        if 'routing_fn' in n:
            fan_in = m.weight.size(1)
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        m.bias.data.zero_()


def initialize_weight_default(m, n=''):
    if isinstance(m, CondConv2d):
        init_fn = get_condconv_initializer(partial(
            nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
        init_fn(m.weight)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')


================================================
FILE: models/gen_efficientnet.py
================================================
""" Generic Efficient Networks

A generic MobileNet class with building blocks to support a variety of models:

* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports)
  - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
  - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
  - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
  - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252

* EfficientNet-Lite

* MixNet (Small, Medium, and Large)
  - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595

* MNasNet B1, A1 (SE), Small
  - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626

* FBNet-C
  - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443

* Single-Path NAS Pixel1
  - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877

* And likely more...

Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F

from .helpers import load_pretrained
from .efficientnet_builder import *

__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140',
           'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small',
           'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2',  'efficientnet_b3',
           'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8',
           'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el',
           'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e',
           'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4',
           'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3',
           'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8',
           'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap',
           'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap',
           'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns',
           'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns',
           'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475',
           'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el',
           'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e',
           'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3',
           'tf_efficientnet_lite4',
           'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l']


model_urls = {
    'mnasnet_050': None,
    'mnasnet_075': None,
    'mnasnet_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
    'mnasnet_140': None,
    'mnasnet_small': None,

    'semnasnet_050': None,
    'semnasnet_075': None,
    'semnasnet_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
    'semnasnet_140': None,

    'fbnetc_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
    'spnasnet_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',

    'efficientnet_b0':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth',
    'efficientnet_b1':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
    'efficientnet_b2': 
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth',
    'efficientnet_b3':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra-a5e2fbc7.pth',
    'efficientnet_b4': None,
    'efficientnet_b5': None,
    'efficientnet_b6': None,
    'efficientnet_b7': None,
    'efficientnet_b8': None,
    'efficientnet_l2': None,

    'efficientnet_es':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
    'efficientnet_em': None,
    'efficientnet_el': None,

    'efficientnet_cc_b0_4e': None,
    'efficientnet_cc_b0_8e': None,
    'efficientnet_cc_b1_8e': None,

    'efficientnet_lite0': None,
    'efficientnet_lite1': None,
    'efficientnet_lite2': None,
    'efficientnet_lite3': None,
    'efficientnet_lite4': None,

    'tf_efficientnet_b0':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
    'tf_efficientnet_b1':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
    'tf_efficientnet_b2':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
    'tf_efficientnet_b3':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
    'tf_efficientnet_b4':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
    'tf_efficientnet_b5':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
    'tf_efficientnet_b6':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
    'tf_efficientnet_b7':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
    'tf_efficientnet_b8':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',

    'tf_efficientnet_b0_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
    'tf_efficientnet_b1_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth',
    'tf_efficientnet_b2_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth',
    'tf_efficientnet_b3_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth',
    'tf_efficientnet_b4_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth',
    'tf_efficientnet_b5_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth',
    'tf_efficientnet_b6_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth',
    'tf_efficientnet_b7_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth',
    'tf_efficientnet_b8_ap':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',

    'tf_efficientnet_b0_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
    'tf_efficientnet_b1_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
    'tf_efficientnet_b2_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
    'tf_efficientnet_b3_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
    'tf_efficientnet_b4_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
    'tf_efficientnet_b5_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
    'tf_efficientnet_b6_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
    'tf_efficientnet_b7_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
    'tf_efficientnet_l2_ns_475':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
    'tf_efficientnet_l2_ns':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',

    'tf_efficientnet_es':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
    'tf_efficientnet_em':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
    'tf_efficientnet_el':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',

    'tf_efficientnet_cc_b0_4e':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
    'tf_efficientnet_cc_b0_8e':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth',
    'tf_efficientnet_cc_b1_8e':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',

    'tf_efficientnet_lite0':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
    'tf_efficientnet_lite1':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
    'tf_efficientnet_lite2':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
    'tf_efficientnet_lite3':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
    'tf_efficientnet_lite4':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',

    'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth',
    'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth',
    'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth',
    'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth',

    'tf_mixnet_s':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth',
    'tf_mixnet_m':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth',
    'tf_mixnet_l':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth',
}


class GenEfficientNet(nn.Module):
    """ Generic EfficientNets

    An implementation of mobile-optimized networks that covers:
      * EfficientNet (B0-B8, L2, CondConv, EdgeTPU)
      * MixNet (Small, Medium, Large, and XL)
      * MNASNet A1, B1, and Small
      * FBNet-C
      * Single-Path NAS Pixel1
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 weight_init='goog'):
        super(GenEfficientNet, self).__init__()
        self.drop_rate = drop_rate
        norm_kwargs = norm_kwargs or {}

        if not fix_stem:
            stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        in_chs = stem_size

        builder = EfficientNetBuilder(
            channel_multiplier, channel_divisor, channel_min,
            pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate)
        self.blocks = nn.Sequential(*builder(in_chs, block_args))
        in_chs = builder.in_chs

        self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type)
        self.bn2 = norm_layer(num_features, **norm_kwargs)
        self.act2 = act_layer(inplace=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(num_features, num_classes)

        for n, m in self.named_modules():
            if weight_init == 'goog':
                initialize_weight_goog(m, n)
            else:
                initialize_weight_default(m, n)
        self.feature_num = num_features

    def features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.conv_head(x)
        x = self.bn2(x)
        x = self.act2(x)
        return x

    def as_sequential(self):
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        layers.extend([
            self.conv_head, self.bn2, self.act2,
            self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
        return nn.Sequential(*layers)

    def forward(self, x):
        feature = self.features(x)
        x = self.global_pool(feature)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x, feature.detach()


def _create_model(model_kwargs, variant, pretrained=False):
    as_sequential = model_kwargs.pop('as_sequential', False)
    model = GenEfficientNet(**model_kwargs)
    if pretrained:
        load_pretrained(model, model_urls[variant])
    if as_sequential:
        model = model.as_sequential()
    return model


def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a mnasnet-a1 model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16_noskip'],
        # stage 1, 112x112 in
        ['ir_r2_k3_s2_e6_c24'],
        # stage 2, 56x56 in
        ['ir_r3_k5_s2_e3_c40_se0.25'],
        # stage 3, 28x28 in
        ['ir_r4_k3_s2_e6_c80'],
        # stage 4, 14x14 in
        ['ir_r2_k3_s1_e6_c112_se0.25'],
        # stage 5, 14x14 in
        ['ir_r3_k5_s2_e6_c160_se0.25'],
        # stage 6, 7x7 in
        ['ir_r1_k3_s1_e6_c320'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        stem_size=32,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a mnasnet-b1 model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_c16_noskip'],
        # stage 1, 112x112 in
        ['ir_r3_k3_s2_e3_c24'],
        # stage 2, 56x56 in
        ['ir_r3_k5_s2_e3_c40'],
        # stage 3, 28x28 in
        ['ir_r3_k5_s2_e6_c80'],
        # stage 4, 14x14 in
        ['ir_r2_k3_s1_e6_c96'],
        # stage 5, 14x14 in
        ['ir_r4_k5_s2_e6_c192'],
        # stage 6, 7x7 in
        ['ir_r1_k3_s1_e6_c320_noskip']
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        stem_size=32,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a mnasnet-b1 model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    arch_def = [
        ['ds_r1_k3_s1_c8'],
        ['ir_r1_k3_s2_e3_c16'],
        ['ir_r2_k3_s2_e6_c16'],
        ['ir_r4_k5_s2_e6_c32_se0.25'],
        ['ir_r3_k3_s1_e6_c32_se0.25'],
        ['ir_r3_k5_s2_e6_c88_se0.25'],
        ['ir_r1_k3_s1_e6_c144']
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        stem_size=8,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """ FBNet-C

        Paper: https://arxiv.org/abs/1812.03443
        Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py

        NOTE: the ref impl above does not cover the 'C' variant here, which was derived
        from the paper; the impl was only used to confirm some building block details.
    """
    arch_def = [
        ['ir_r1_k3_s1_e1_c16'],
        ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
        ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
        ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
        ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
        ['ir_r4_k5_s2_e6_c184'],
        ['ir_r1_k3_s1_e6_c352'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        stem_size=16,
        num_features=1984,  # paper suggests this, but is not 100% clear
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates the Single-Path NAS model from search targeted for Pixel1 phone.

    Paper: https://arxiv.org/abs/1904.02877

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_c16_noskip'],
        # stage 1, 112x112 in
        ['ir_r3_k3_s2_e3_c24'],
        # stage 2, 56x56 in
        ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
        # stage 3, 28x28 in
        ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
        # stage 4, 14x14 in
        ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
        # stage 5, 14x14 in
        ['ir_r4_k5_s2_e6_c192'],
        # stage 6, 7x7 in
        ['ir_r1_k3_s1_e6_c320_noskip']
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        stem_size=32,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Creates an EfficientNet model.

    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
    'efficientnet-b8': (2.2, 3.6, 672, 0.5),

    Args:
      channel_multiplier: multiplier to number of channels per layer
      depth_multiplier: multiplier to number of repeats per stage

    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        ['ir_r2_k3_s2_e6_c24_se0.25'],
        ['ir_r2_k5_s2_e6_c40_se0.25'],
        ['ir_r3_k3_s2_e6_c80_se0.25'],
        ['ir_r3_k5_s1_e6_c112_se0.25'],
        ['ir_r4_k5_s2_e6_c192_se0.25'],
        ['ir_r1_k3_s1_e6_c320_se0.25'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def, depth_multiplier),
        num_features=round_channels(1280, channel_multiplier, 8, None),
        stem_size=32,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'swish'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs,
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model
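The `num_features=round_channels(1280, channel_multiplier, 8, None)` call above widens the head to the nearest divisor-friendly channel count. A standalone sketch of the usual divisible-channel rounding EfficientNet implementations use (the real `round_channels` lives in `efficientnet_builder`; this mirrors its typical behavior but is only an illustration):

```python
def round_channels_sketch(channels, multiplier=1.0, divisor=8, channel_min=None):
    # sketch of divisible-channel rounding as commonly used by EfficientNet impls
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    # round to the nearest multiple of `divisor`, but never below `channel_min`
    new_channels = max(channel_min, int(channels + divisor / 2) // divisor * divisor)
    if new_channels < 0.9 * channels:    # never round down by more than ~10%
        new_channels += divisor
    return new_channels

# B2 head width: 1280 * 1.1 = 1408, already a multiple of 8
# round_channels_sketch(1280, 1.1) -> 1408
```

This is why the B2/B3 variants end up with 1408- and 1536-wide heads rather than the raw 1280.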


def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    arch_def = [
        # NOTE: `fc` overrides a mismatch between the stem channels and block
        # in_chs that is not present in other models
        ['er_r1_k3_s1_e4_c24_fc24_noskip'],
        ['er_r2_k3_s2_e8_c32'],
        ['er_r4_k3_s2_e8_c48'],
        ['ir_r5_k5_s2_e8_c96'],
        ['ir_r4_k5_s1_e8_c144'],
        ['ir_r2_k5_s2_e8_c192'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def, depth_multiplier),
        num_features=round_channels(1280, channel_multiplier, 8, None),
        stem_size=32,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs,
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_efficientnet_condconv(
        variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
    """Creates an efficientnet-condconv model."""
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        ['ir_r2_k3_s2_e6_c24_se0.25'],
        ['ir_r2_k5_s2_e6_c40_se0.25'],
        ['ir_r3_k3_s2_e6_c80_se0.25'],
        ['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
        ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
        ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
        num_features=round_channels(1280, channel_multiplier, 8, None),
        stem_size=32,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'swish'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs,
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Creates an EfficientNet-Lite model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
      'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
      'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
      'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
      'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
      'efficientnet-lite4': (1.4, 1.8, 300, 0.3),

    Args:
      channel_multiplier: multiplier to number of channels per layer
      depth_multiplier: multiplier to number of repeats per stage
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16'],
        ['ir_r2_k3_s2_e6_c24'],
        ['ir_r2_k5_s2_e6_c40'],
        ['ir_r3_k3_s2_e6_c80'],
        ['ir_r3_k5_s1_e6_c112'],
        ['ir_r4_k5_s2_e6_c192'],
        ['ir_r1_k3_s1_e6_c320'],
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
        num_features=1280,
        stem_size=32,
        fix_stem=True,
        channel_multiplier=channel_multiplier,
        act_layer=nn.ReLU6,
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs,
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model
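Several generators above pass `num_features=round_channels(1280, channel_multiplier, 8, None)` to scale the head width while keeping it divisible by 8. A minimal sketch of that rounding rule, assuming it follows the standard EfficientNet `make_divisible` scheme (the function name here is illustrative, not the library's):

```python
def round_channels_sketch(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Scale `channels` by `multiplier`, then round to the nearest multiple of
    `divisor`, never dropping more than 10% below the scaled value."""
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    new_channels = max(channel_min, int(channels + divisor / 2) // divisor * divisor)
    if new_channels < 0.9 * channels:  # avoid rounding away >10% of the width
        new_channels += divisor
    return int(new_channels)

print(round_channels_sketch(1280, 1.1))  # 1408: already a multiple of 8
print(round_channels_sketch(16, 0.5))    # 8
```

The 10% guard matters most for narrow layers, where plain round-down could remove a large fraction of the channels.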


def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a MixNet Small model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
    Paper: https://arxiv.org/abs/1907.09595
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'],  # relu
        # stage 2, 56x56 in
        ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'],  # swish
        # stage 3, 28x28 in
        ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'],  # swish
        # stage 4, 14x14 in
        ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
        # stage 5, 14x14 in
        ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'],  # swish
        # 7x7
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=1536,
        stem_size=16,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model


def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a MixNet Medium-Large model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
    Paper: https://arxiv.org/abs/1907.09595
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c24'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'],  # relu
        # stage 2, 56x56 in
        ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'],  # swish
        # stage 3, 28x28 in
        ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'],  # swish
        # stage 4, 14x14 in
        ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
        # stage 5, 14x14 in
        ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'],  # swish
        # 7x7
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
        num_features=1536,
        stem_size=24,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'relu'),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model
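The arch_def strings above are compact block specs: the leading token names the block type (`ds` depthwise-separable, `ir` inverted residual), and the remaining tokens encode repeats (`r`), kernel size (`k`), stride (`s`), expansion ratio (`e`), output channels (`c`), and squeeze-excite ratio (`se`). A simplified stand-in parser for a single token — it ignores multi-kernel specs like `k3.5.7` and tags like `nsw`, which the real `decode_arch_def` handles:

```python
def parse_block_str(block_str):
    """Parse a simplified arch-def token, e.g. 'ir_r2_k3_s2_e6_c24_se0.25'."""
    parts = block_str.split('_')
    spec = {'type': parts[0]}  # 'ds' or 'ir'
    fields = {'se': ('se_ratio', float), 'r': ('repeats', int), 'k': ('kernel', int),
              's': ('stride', int), 'e': ('exp_ratio', float), 'c': ('channels', int)}
    for part in parts[1:]:
        for prefix, (name, cast) in fields.items():
            value = part[len(prefix):]
            if part.startswith(prefix) and value.replace('.', '').isdigit():
                spec[name] = cast(value)
                break  # each token sets exactly one field
    return spec

print(parse_block_str('ir_r2_k3_s2_e6_c24_se0.25'))
```

Checking `se` before `s` (and requiring the remainder to be numeric) keeps `se0.25` from being misread as a stride token.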


def mnasnet_050(pretrained=False, **kwargs):
    """ MNASNet B1, depth multiplier of 0.5. """
    model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
    return model


def mnasnet_075(pretrained=False, **kwargs):
    """ MNASNet B1, depth multiplier of 0.75. """
    model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def mnasnet_100(pretrained=False, **kwargs):
    """ MNASNet B1, depth multiplier of 1.0. """
    model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def mnasnet_b1(pretrained=False, **kwargs):
    """ MNASNet B1, depth multiplier of 1.0. """
    return mnasnet_100(pretrained=pretrained, **kwargs)


def mnasnet_140(pretrained=False, **kwargs):
    """ MNASNet B1,  depth multiplier of 1.4 """
    model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
    return model


def semnasnet_050(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """
    model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
    return model


def semnasnet_075(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE),  depth multiplier of 0.75. """
    model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def semnasnet_100(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
    model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def mnasnet_a1(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
    return semnasnet_100(pretrained=pretrained, **kwargs)


def semnasnet_140(pretrained=False, **kwargs):
    """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """
    model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
    return model


def mnasnet_small(pretrained=False, **kwargs):
    """ MNASNet Small,  depth multiplier of 1.0. """
    model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
    return model


def fbnetc_100(pretrained=False, **kwargs):
    """ FBNet-C """
    if pretrained:
        # pretrained model trained with non-default BN epsilon
        kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def spnasnet_100(pretrained=False, **kwargs):
    """ Single-Path NAS Pixel1"""
    model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 """
    # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b1(pretrained=False, **kwargs):
    """ EfficientNet-B1 """
    # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b2(pretrained=False, **kwargs):
    """ EfficientNet-B2 """
    # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b3(pretrained=False, **kwargs):
    """ EfficientNet-B3 """
    # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b4(pretrained=False, **kwargs):
    """ EfficientNet-B4 """
    # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b5(pretrained=False, **kwargs):
    """ EfficientNet-B5 """
    # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b6(pretrained=False, **kwargs):
    """ EfficientNet-B6 """
    # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b7(pretrained=False, **kwargs):
    """ EfficientNet-B7 """
    # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
    return model


def efficientnet_b8(pretrained=False, **kwargs):
    """ EfficientNet-B8 """
    # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
    model = _gen_efficientnet(
        'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
    return model


def efficientnet_l2(pretrained=False, **kwargs):
    """ EfficientNet-L2. """
    # NOTE for train, drop_rate should be 0.5
    model = _gen_efficientnet(
        'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
    return model


def efficientnet_es(pretrained=False, **kwargs):
    """ EfficientNet-Edge Small. """
    model = _gen_efficientnet_edge(
        'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def efficientnet_em(pretrained=False, **kwargs):
    """ EfficientNet-Edge-Medium. """
    model = _gen_efficientnet_edge(
        'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def efficientnet_el(pretrained=False, **kwargs):
    """ EfficientNet-Edge-Large. """
    model = _gen_efficientnet_edge(
        'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 8 Experts """
    # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
    model = _gen_efficientnet_condconv(
        'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 8 Experts """
    # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
    model = _gen_efficientnet_condconv(
        'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
        pretrained=pretrained, **kwargs)
    return model


def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B1 w/ 8 Experts """
    # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
    model = _gen_efficientnet_condconv(
        'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
        pretrained=pretrained, **kwargs)
    return model


def efficientnet_lite0(pretrained=False, **kwargs):
    """ EfficientNet-Lite0 """
    model = _gen_efficientnet_lite(
        'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def efficientnet_lite1(pretrained=False, **kwargs):
    """ EfficientNet-Lite1 """
    model = _gen_efficientnet_lite(
        'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def efficientnet_lite2(pretrained=False, **kwargs):
    """ EfficientNet-Lite2 """
    model = _gen_efficientnet_lite(
        'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def efficientnet_lite3(pretrained=False, **kwargs):
    """ EfficientNet-Lite3 """
    model = _gen_efficientnet_lite(
        'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def efficientnet_lite4(pretrained=False, **kwargs):
    """ EfficientNet-Lite4 """
    model = _gen_efficientnet_lite(
        'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 AutoAug. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b1(pretrained=False, **kwargs):
    """ EfficientNet-B1 AutoAug. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b2(pretrained=False, **kwargs):
    """ EfficientNet-B2 AutoAug. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b3(pretrained=False, **kwargs):
    """ EfficientNet-B3 AutoAug. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b4(pretrained=False, **kwargs):
    """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b5(pretrained=False, **kwargs):
    """ EfficientNet-B5 RandAug. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b6(pretrained=False, **kwargs):
    """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b7(pretrained=False, **kwargs):
    """ EfficientNet-B7 RandAug. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b8(pretrained=False, **kwargs):
    """ EfficientNet-B8 RandAug. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b0_ap(pretrained=False, **kwargs):
    """ EfficientNet-B0 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b1_ap(pretrained=False, **kwargs):
    """ EfficientNet-B1 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b2_ap(pretrained=False, **kwargs):
    """ EfficientNet-B2 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b3_ap(pretrained=False, **kwargs):
    """ EfficientNet-B3 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b4_ap(pretrained=False, **kwargs):
    """ EfficientNet-B4 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b5_ap(pretrained=False, **kwargs):
    """ EfficientNet-B5 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b6_ap(pretrained=False, **kwargs):
    """ EfficientNet-B6 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b7_ap(pretrained=False, **kwargs):
    """ EfficientNet-B7 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
    """ EfficientNet-B8 AdvProp. Tensorflow compatible variant
    Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
    """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
    """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
    """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
    """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
    """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
    """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
    """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
    """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
    """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
    """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant
    Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
    """
    # NOTE for train, drop_rate should be 0.5
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet(
        'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_es(pretrained=False, **kwargs):
    """ EfficientNet-Edge Small. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_em(pretrained=False, **kwargs):
    """ EfficientNet-Edge-Medium. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_el(pretrained=False, **kwargs):
    """ EfficientNet-Edge-Large. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 4 Experts """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B0 w/ 8 Experts """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
        pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
    """ EfficientNet-CondConv-B1 w/ 8 Experts """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_condconv(
        'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
        pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_lite0(pretrained=False, **kwargs):
    """ EfficientNet-Lite0. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_lite(
        'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_lite1(pretrained=False, **kwargs):
    """ EfficientNet-Lite1. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_lite(
        'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_lite2(pretrained=False, **kwargs):
    """ EfficientNet-Lite2. Tensorflow compatible variant  """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_lite(
        'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_lite3(pretrained=False, **kwargs):
    """ EfficientNet-Lite3. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_lite(
        'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
    return model


def tf_efficientnet_lite4(pretrained=False, **kwargs):
    """ EfficientNet-Lite4. Tensorflow compatible variant """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_lite(
        'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
    return model


def mixnet_s(pretrained=False, **kwargs):
    """Creates a MixNet Small model.
    """
    # NOTE for train set drop_rate=0.2
    model = _gen_mixnet_s(
        'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def mixnet_m(pretrained=False, **kwargs):
    """Creates a MixNet Medium model.
    """
    # NOTE for train set drop_rate=0.25
    model = _gen_mixnet_m(
        'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def mixnet_l(pretrained=False, **kwargs):
    """Creates a MixNet Large model.
    """
    # NOTE for train set drop_rate=0.25
    model = _gen_mixnet_m(
        'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
    return model


def mixnet_xl(pretrained=False, **kwargs):
    """Creates a MixNet Extra-Large model.
    Not a paper spec, experimental def by RW w/ depth scaling.
    """
    # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
    model = _gen_mixnet_m(
        'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
    return model


def mixnet_xxl(pretrained=False, **kwargs):
    """Creates a MixNet Double Extra Large model.
    Not a paper spec, experimental def by RW w/ depth scaling.
    """
    # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
    model = _gen_mixnet_m(
        'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)
    return model


def tf_mixnet_s(pretrained=False, **kwargs):
    """Creates a MixNet Small model. Tensorflow compatible variant
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mixnet_s(
        'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mixnet_m(pretrained=False, **kwargs):
    """Creates a MixNet Medium model. Tensorflow compatible variant
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mixnet_m(
        'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mixnet_l(pretrained=False, **kwargs):
    """Creates a MixNet Large model. Tensorflow compatible variant
    """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mixnet_m(
        'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
    return model


================================================
FILE: models/helpers.py
================================================
import torch
import os
from collections import OrderedDict
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


def load_checkpoint(model, checkpoint_path):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        print("=> Loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('module.'):
                    name = k[7:]  # remove `module.` prefix added by nn.DataParallel
                else:
                    name = k
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(checkpoint)
        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
    else:
        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()
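`load_checkpoint` strips the `module.` prefix that `nn.DataParallel` adds to every key, so weights saved from a wrapped model load into an unwrapped one. The renaming step in isolation, on a toy state dict:

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Remove a leading 'module.' (added by nn.DataParallel) from every key."""
    out = OrderedDict()
    for k, v in state_dict.items():
        out[k[7:] if k.startswith('module.') else k] = v
    return out

sd = OrderedDict([('module.conv_stem.weight', 'w'), ('classifier.bias', 'b')])
print(list(strip_module_prefix(sd)))  # ['conv_stem.weight', 'classifier.bias']
```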


def load_pretrained(model, url, filter_fn=None, strict=True):
    if not url:
        print("=> Warning: Pretrained model URL is empty, using random initialization.")
        return

    if url.startswith(('http://', 'https://')):
        state_dict = load_state_dict_from_url(url, map_location='cpu')
    else:
        state_dict = torch.load(url, map_location='cpu')

    input_conv = 'conv_stem'
    classifier = 'classifier'
    in_chans = getattr(model, input_conv).weight.shape[1]
    num_classes = getattr(model, classifier).weight.shape[0]

    input_conv_weight = input_conv + '.weight'
    pretrained_in_chans = state_dict[input_conv_weight].shape[1]
    if in_chans != pretrained_in_chans:
        if in_chans == 1:
            print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
                input_conv_weight, pretrained_in_chans))
            conv1_weight = state_dict[input_conv_weight]
            state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
        else:
            print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
                input_conv_weight, pretrained_in_chans))
            del state_dict[input_conv_weight]
            strict = False

    classifier_weight = classifier + '.weight'
    pretrained_num_classes = state_dict[classifier_weight].shape[0]
    if num_classes != pretrained_num_classes:
        print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
        del state_dict[classifier_weight]
        del state_dict[classifier + '.bias']
        strict = False

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    model.load_state_dict(state_dict, strict=strict)
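When the model expects one input channel, `load_pretrained` adapts the 3-channel stem conv by summing over the input-channel dimension (`conv1_weight.sum(dim=1, keepdim=True)`). A stdlib sketch of that reduction on a single nested-list kernel of hypothetical shape (channels, kH, kW); the real code does this per output filter on a 4-D tensor:

```python
# Illustration: collapse a (3, kH, kW) RGB kernel to (1, kH, kW) by summing
# the input channels, mirroring conv1_weight.sum(dim=1, keepdim=True).
def sum_input_channels(kernel):
    chans = len(kernel)
    rows, cols = len(kernel[0]), len(kernel[0][0])
    return [[[sum(kernel[c][i][j] for c in range(chans))
              for j in range(cols)] for i in range(rows)]]

rgb = [[[1, 2], [3, 4]],    # R
       [[5, 6], [7, 8]],    # G
       [[9, 10], [11, 12]]]  # B
gray = sum_input_channels(rgb)
```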


================================================
FILE: models/mobilenetv3.py
================================================
""" MobileNet-V3

A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.

Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244

Hacked together by Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F

from .helpers import load_pretrained
from .efficientnet_builder import *

__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
           'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
           'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
           'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100', 'mobilenetv3_large_125']

model_urls = {
    'mobilenetv3_rw':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
    'mobilenetv3_large_075': None,
    'mobilenetv3_large_100': None,
    'mobilenetv3_large_125': None,
    'mobilenetv3_large_minimal_100': None,
    'mobilenetv3_small_075': None,
    'mobilenetv3_small_100': None,
    'mobilenetv3_small_minimal_100': None,
    'tf_mobilenetv3_large_075':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
    'tf_mobilenetv3_large_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
    'tf_mobilenetv3_large_minimal_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
    'tf_mobilenetv3_small_075':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
    'tf_mobilenetv3_small_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
    'tf_mobilenetv3_small_minimal_100':
        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
}


class MobileNetV3(nn.Module):
    """ MobileNet-V3

    This model utilizes the MobileNet-V3 specific 'efficient head', where global pooling is done before the
    head convolution, without a final batch-norm layer before the classifier.

    Paper: https://arxiv.org/abs/1905.02244
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
                 channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
        super(MobileNetV3, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.drop_rate = drop_rate

        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        in_chs = stem_size

        builder = EfficientNetBuilder(
            channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
            norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
        self.blocks = nn.Sequential(*builder(in_chs, block_args))
        in_chs = builder.in_chs

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
        self.act2 = act_layer(inplace=True)

        self.feature_num = num_features
        self.classifier = nn.Linear(num_features, num_classes)

        for m in self.modules():
            if weight_init == 'goog':
                initialize_weight_goog(m)
            else:
                initialize_weight_default(m)

    def as_sequential(self):
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        layers.extend([
            self.global_pool, self.conv_head, self.act2,
            nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
        return nn.Sequential(*layers)

    def features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        features = self.blocks(x)
        x = self.global_pool(features)
        x = self.conv_head(x)
        x = self.act2(x)
        return x, features.detach()

    def forward(self, x):
        x, features = self.features(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x, features.detach()


def _create_model(model_kwargs, variant, pretrained=False):
    as_sequential = model_kwargs.pop('as_sequential', False)
    model = MobileNetV3(**model_kwargs)
    if pretrained and model_urls[variant]:
        load_pretrained(model, model_urls[variant])
    if as_sequential:
        model = model.as_sequential()
    return model


def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a MobileNet-V3 model (RW variant).

    Paper: https://arxiv.org/abs/1905.02244

    This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
    eventual Tensorflow reference impl but has a few differences:
    1. This model has no bias on the head convolution
    2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet
    3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
       from their parent block
    4. This model does not enforce divisible by 8 limitation on the SE reduction channel count

    Overall the changes are fairly minor and result in a very small parameter count difference and no
    meaningful top-1/top-5 accuracy difference.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16_nre_noskip'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
        # stage 2, 56x56 in
        ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
        # stage 3, 28x28 in
        ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
        # stage 4, 14x14in
        ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
        # stage 5, 14x14in
        ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
        # stage 6, 7x7 in
        ['cn_r1_k1_s1_c960'],  # hard-swish
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        head_bias=False,  # one of my mistakes
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs,
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model
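The `arch_def` strings above pack a block's hyper-parameters into underscore-separated tags. A hedged stdlib sketch of how one such string might decode (the real parsing lives in `decode_arch_def` in efficientnet_builder.py; the tag meanings here are inferred from the comments above: r=repeats, k=kernel, s=stride, e=expansion, c=channels, se=squeeze-excite ratio, nre=use ReLU):

```python
import re

# Illustration (not the repo's decoder): parse 'ir_r1_k5_s2_e4_c40_se0.25_nre'.
def parse_block(spec):
    parts = spec.split('_')
    out = {'type': parts[0], 'relu': 'nre' in parts, 'noskip': 'noskip' in parts}
    for p in parts[1:]:
        m = re.match(r'([a-z]+)([\d.]+)$', p)  # letter tag followed by a number
        if m:
            val = m.group(2)
            out[m.group(1)] = float(val) if '.' in val else int(val)
    return out

blk = parse_block('ir_r1_k5_s2_e4_c40_se0.25_nre')
```

Flag-only tags such as `nre` and `noskip` carry no number, so the regex skips them and they are handled as booleans.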


def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a MobileNet-V3 large/small/minimal models.

    Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
    Paper: https://arxiv.org/abs/1905.02244

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    if 'small' in variant:
        num_features = 1024
        if 'minimal' in variant:
            act_layer = 'relu'
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s2_e1_c16'],
                # stage 1, 56x56 in
                ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
                # stage 2, 28x28 in
                ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
                # stage 3, 14x14 in
                ['ir_r2_k3_s1_e3_c48'],
                # stage 4, 14x14in
                ['ir_r3_k3_s2_e6_c96'],
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c576'],
            ]
        else:
            act_layer = 'hard_swish'
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s2_e1_c16_se0.25_nre'],  # relu
                # stage 1, 56x56 in
                ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'],  # relu
                # stage 2, 28x28 in
                ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'],  # hard-swish
                # stage 3, 14x14 in
                ['ir_r2_k5_s1_e3_c48_se0.25'],  # hard-swish
                # stage 4, 14x14in
                ['ir_r3_k5_s2_e6_c96_se0.25'],  # hard-swish
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c576'],  # hard-swish
            ]
    else:
        num_features = 1280
        if 'minimal' in variant:
            act_layer = 'relu'
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s1_e1_c16'],
                # stage 1, 112x112 in
                ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
                # stage 2, 56x56 in
                ['ir_r3_k3_s2_e3_c40'],
                # stage 3, 28x28 in
                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
                # stage 4, 14x14in
                ['ir_r2_k3_s1_e6_c112'],
                # stage 5, 14x14in
                ['ir_r3_k3_s2_e6_c160'],
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c960'],
            ]
        else:
            act_layer = 'hard_swish'
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s1_e1_c16_nre'],  # relu
                # stage 1, 112x112 in
                ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
                # stage 2, 56x56 in
                ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
                # stage 3, 28x28 in
                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
                # stage 4, 14x14in
                ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
                # stage 5, 14x14in
                ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c960'],  # hard-swish
            ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
        stem_size=16,
        channel_multiplier=channel_multiplier,
        act_layer=resolve_act_layer(kwargs, act_layer),
        se_kwargs=dict(
            act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8),
        norm_kwargs=resolve_bn_args(kwargs),
        **kwargs,
    )
    model = _create_model(model_kwargs, variant, pretrained)
    return model
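`channel_multiplier` scales every stage's channel count, and `round_channels` (with the `divisor=8` seen in `se_kwargs`) snaps the result to a hardware-friendly multiple. A sketch of the standard make-divisible rule used across the MobileNet family; the repo's `round_channels` in efficientnet_builder.py is assumed to follow the same scheme:

```python
# Illustration of the usual make-divisible rule for channel rounding.
def make_divisible(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # never round down by more than 10%
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# e.g. channel_multiplier=0.75 applied to a 40-channel stage: 30 -> 32
scaled = make_divisible(40 * 0.75)
```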


def mobilenetv3_rw(pretrained=False, **kwargs):
    """ MobileNet-V3 RW
    Attn: See note in gen function for this variant.
    """
    # NOTE for train set drop_rate=0.2
    if pretrained:
        # pretrained model trained with non-default BN epsilon
        kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
    return model


def mobilenetv3_large_075(pretrained=False, **kwargs):
    """ MobileNet V3 Large 0.75"""
    # NOTE for train set drop_rate=0.2
    model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def mobilenetv3_large_100(pretrained=False, **kwargs):
    """ MobileNet V3 Large 1.0 """
    # NOTE for train set drop_rate=0.2
    model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def mobilenetv3_large_125(pretrained=False, **kwargs):
    """ MobileNet V3 Large 1.25 """
    # NOTE for train set drop_rate=0.2
    model = _gen_mobilenet_v3('mobilenetv3_large_125', 1.25, pretrained=pretrained, **kwargs)
    return model



def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
    """ MobileNet V3 Large (Minimalistic) 1.0 """
    # NOTE for train set drop_rate=0.2
    model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def mobilenetv3_small_075(pretrained=False, **kwargs):
    """ MobileNet V3 Small 0.75 """
    model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def mobilenetv3_small_100(pretrained=False, **kwargs):
    """ MobileNet V3 Small 1.0 """
    model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
    """ MobileNet V3 Small (Minimalistic) 1.0 """
    model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
    """ MobileNet V3 Large 0.75. Tensorflow compat variant. """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
    """ MobileNet V3 Large 1.0. Tensorflow compat variant. """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
    """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
    """ MobileNet V3 Small 0.75. Tensorflow compat variant. """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
    return model


def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
    """ MobileNet V3 Small 1.0. Tensorflow compat variant."""
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
    return model


def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
    """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
    return model


================================================
FILE: models/model_factory.py
================================================
from .mobilenetv3 import *
from .gen_efficientnet import *
from .helpers import load_checkpoint


def create_model(
        model_name='mnasnet_100',
        pretrained=None,
        num_classes=1000,
        in_chans=3,
        checkpoint_path='',
        **kwargs):

    margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)

    if model_name in globals():
        create_fn = globals()[model_name]
        model = create_fn(**margs, **kwargs)
    else:
        raise RuntimeError('Unknown model (%s)' % model_name)

    if checkpoint_path and not pretrained:
        load_checkpoint(model, checkpoint_path)

    return model
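`create_model` dispatches on the model name by looking the entry-point function up in `globals()`. A minimal stdlib sketch of the same name-to-constructor registry pattern (all names here are hypothetical):

```python
# Illustration: name -> constructor dispatch, as create_model does via globals().
def tiny_net(num_classes=10, **kwargs):
    return ('tiny_net', num_classes)

_MODELS = {'tiny_net': tiny_net}

def create(model_name, **kwargs):
    if model_name not in _MODELS:
        raise RuntimeError('Unknown model (%s)' % model_name)
    return _MODELS[model_name](**kwargs)

net = create('tiny_net', num_classes=5)
```

An explicit dict registry behaves like the `globals()` lookup but avoids accidentally exposing unrelated module-level names as "models".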


================================================
FILE: models/resnet.py
================================================
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, groups=groups, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.inplanes = 64
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        output = self.avgpool(x)
        output = output.view(x.size(0), -1)
        return output, x.detach()
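Each `_make_layer` stage multiplies its base width by `block.expansion`, so the classifier input is `512 * block.expansion` features. A quick sketch of the per-stage output channels for the two block types defined above:

```python
# Illustration: output channels of the four ResNet stages for each block type.
def stage_channels(expansion, bases=(64, 128, 256, 512)):
    return [b * expansion for b in bases]

basic = stage_channels(1)       # BasicBlock (resnet18/34): expansion = 1
bottleneck = stage_channels(4)  # Bottleneck (resnet50/101/152): expansion = 4
```

This is why `self.fc` above is `nn.Linear(512 * block.expansion, num_classes)`, i.e. 512 inputs for BasicBlock models and 2048 for Bottleneck models.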


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))

    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))

    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model


def resnext50_32x4d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-50 32x4d model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d']))
    return model


def resnext101_32x8d(pretrained=False, **kwargs):
    """Constructs a ResNeXt-101 32x8d model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d']))
    return model


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


================================================
FILE: models/version.py
================================================
__version__ = '0.9.8'


================================================
FILE: network.py
================================================
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import math


class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
        self.hidden = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]
        del self.hidden[:]


class ActorCritic(nn.Module):
    def __init__(self, feature_dim, state_dim, hidden_state_dim=1024, policy_conv=True, action_std=0.1):
        super(ActorCritic, self).__init__()
        
        # encoder with convolution layer for MobileNetV3, EfficientNet and RegNet
        if policy_conv:
            self.state_encoder = nn.Sequential(
                nn.Conv2d(feature_dim, 32, kernel_size=1, stride=1, padding=0, bias=False),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(int(state_dim * 32 / feature_dim), hidden_state_dim),
                nn.ReLU()
            )
        # encoder with linear layer for ResNet and DenseNet
        else:
            self.state_encoder = nn.Sequential(
                nn.Linear(state_dim, 2048),
                nn.ReLU(),
                nn.Linear(2048, hidden_state_dim),
                nn.ReLU()
            )

        self.gru = nn.GRU(hidden_state_dim, hidden_state_dim, batch_first=False)
        
        self.actor = nn.Sequential(
            nn.Linear(hidden_state_dim, 2),
            nn.Sigmoid())

        self.critic = nn.Sequential(
            nn.Linear(hidden_state_dim, 1))

        self.action_var = torch.full((2,), action_std).cuda()

        self.hidden_state_dim = hidden_state_dim
        self.policy_conv = policy_conv
        self.feature_dim = feature_dim
        self.feature_ratio = int(math.sqrt(state_dim/feature_dim))

    def forward(self):
        raise NotImplementedError

    def act(self, state_ini, memory, restart_batch=False, training=False):
        if restart_batch:
            del memory.hidden[:]
            memory.hidden.append(torch.zeros(1, state_ini.size(0), self.hidden_state_dim).cuda())

        if not self.policy_conv:
            state = state_ini.flatten(1)
        else:
            state = state_ini

        state = self.state_encoder(state)

        state, hidden_output = self.gru(state.view(1, state.size(0), state.size(1)), memory.hidden[-1])
        memory.hidden.append(hidden_output)

        state = state[0]
        action_mean = self.actor(state)

        # diag(action_var) is passed as scale_tril, so the action_var entries act as per-dimension std devs
        cov_mat = torch.diag(self.action_var).cuda()
        dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat)
        action = dist.sample().cuda()
        if training:
            # clamp the sampled action to [0, 1]: relu removes negatives, 1 - relu(1 - x) caps at 1
            action = F.relu(action)
            action = 1 - F.relu(1 - action)
            action_logprob = dist.log_prob(action).cuda()
            memory.states.append(state_ini)
            memory.actions.append(action)
            memory.logprobs.append(action_logprob)
        else:
            action = action_mean

        return action.detach()

    def evaluate(self, state, action):
        seq_l = state.size(0)
        batch_size = state.size(1)

        if not self.policy_conv:
            state = state.flatten(2)
            state = state.view(seq_l * batch_size, state.size(2))
        else:
            state = state.view(seq_l * batch_size, state.size(2), state.size(3), state.size(4))

        state = self.state_encoder(state)
        state = state.view(seq_l, batch_size, -1)

        state, hidden = self.gru(state, torch.zeros(1, batch_size, state.size(2)).cuda())
        state = state.view(seq_l * batch_size, -1)

        action_mean = self.actor(state)

        cov_mat = torch.diag(self.action_var).cuda()

        dist = torch.distributions.multivariate_normal.MultivariateNormal(action_mean, scale_tril=cov_mat)

        action_logprobs = dist.log_prob(torch.squeeze(action.view(seq_l * batch_size, -1))).cuda()
        dist_entropy = dist.entropy().cuda()
        state_value = self.critic(state)

        return action_logprobs.view(seq_l, batch_size), \
               state_value.view(seq_l, batch_size), \
               dist_entropy.view(seq_l, batch_size)
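`act()` clamps each sampled action into [0, 1] with two ReLUs rather than `torch.clamp`. The identity it relies on, checked in plain Python:

```python
# Illustration: the ReLU-based clamp in act() is min(max(x, 0), 1).
def relu(x):
    return x if x > 0 else 0.0

def clamp01(x):
    return 1 - relu(1 - relu(x))

vals = [clamp01(v) for v in (-0.3, 0.4, 1.7)]
```

Using ReLU keeps the operation expressible with the same `F.relu` primitive already imported, with identical results to an explicit clamp.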


class PPO:
    def __init__(self, feature_dim, state_dim, hidden_state_dim, policy_conv,
                 action_std=0.1, lr=0.0003, betas=(0.9, 0.999), gamma=0.7, K_epochs=1, eps_clip=0.2):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda()

        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        self.policy_old = ActorCritic(feature_dim, state_dim, hidden_state_dim, policy_conv, action_std).cuda()
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state, memory, restart_batch=False, training=True):
        return self.policy_old.act(state, memory, restart_batch, training)

    def update(self, memory):
        rewards = []
        discounted_reward = 0

        # Monte-Carlo estimate of the discounted return, computed backwards
        # over the stored rewards: G_t = r_t + gamma * G_{t+1}
        for reward in reversed(memory.rewards):
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        rewards = torch.cat(rewards, 0).cuda()

        # normalize returns for training stability
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        old_states = torch.stack(memory.states, 0).cuda().detach()
        old_actions = torch.stack(memory.actions, 0).cuda().detach()
        old_logprobs = torch.stack(memory.logprobs, 0).cuda().detach()

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # importance-sampling ratio pi_theta(a|s) / pi_theta_old(a|s),
            # computed in log space for numerical stability
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # PPO clipped surrogate objective, plus a value loss term
            # and an entropy bonus to encourage exploration
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages

            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, rewards) - 0.01 * dist_entropy

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        self.policy_old.load_state_dict(self.policy.state_dict())
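
# Illustration (not part of the original network.py): a minimal, torch-free
# sketch of the two pieces of math in update() above -- the backward recursion
# G_t = r_t + gamma * G_{t+1} for discounted returns, and the clipped
# surrogate min(r*A, clip(r, 1-eps, 1+eps)*A). Names here are hypothetical.

```python
import math

def discounted_returns(rewards, gamma):
    # backward recursion: G_t = r_t + gamma * G_{t+1}
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    return returns

def clipped_surrogate(logp_new, logp_old, advantage, eps_clip=0.2):
    # importance ratio pi_new / pi_old, computed in log space
    ratio = math.exp(logp_new - logp_old)
    clipped = max(min(ratio, 1.0 + eps_clip), 1.0 - eps_clip)
    # PPO takes the pessimistic (minimum) of the two surrogates
    return min(ratio * advantage, clipped * advantage)
```

For example, `discounted_returns([1.0, 2.0, 3.0], 0.5)` yields `[2.75, 3.5, 3.0]`, and for a positive advantage a ratio above `1 + eps_clip` is capped at the clipped surrogate.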


class Full_layer(torch.nn.Module):
    def __init__(self, feature_num, hidden_state_dim=1024, fc_rnn=True, class_num=1000):
        super(Full_layer, self).__init__()
        self.class_num = class_num
        self.feature_num = feature_num

        self.hidden_state_dim = hidden_state_dim
        self.hidden = None
        self.fc_rnn = fc_rnn
        
        # classifier with RNN for ResNet, DenseNet and RegNet
        if fc_rnn:
            self.rnn = nn.GRU(feature_num, self.hidden_state_dim)
            self.fc = nn.Linear(self.hidden_state_dim, class_num)
        # cascaded classifier for MobileNetV3 and EfficientNet
        else:
            self.fc_2 = nn.Linear(self.feature_num * 2, class_num)
            self.fc_3 = nn.Linear(self.feature_num * 3, class_num)
            self.fc_4 = nn.Linear(self.feature_num * 4, class_num)
            self.fc_5 = nn.Linear(self.feature_num * 5, class_num)

    def forward(self, x, restart=False):

        if self.fc_rnn:
            # reset the GRU hidden state at the first patch of each sequence
            if restart:
                output, self.hidden = self.rnn(x.view(1, x.size(0), x.size(1)),
                                               torch.zeros(1, x.size(0), self.hidden_state_dim).cuda())
            else:
                output, self.hidden = self.rnn(x.view(1, x.size(0), x.size(1)), self.hidden)

            return self.fc(output[0])
        else:
            if restart:
                self.hidden = x
            else:
                self.hidden = torch.cat([self.hidden, x], 1)

            # the cascaded classifiers start predicting from the second patch;
            # step t applies fc_t to the concatenation of all features seen so far
            if self.hidden.size(1) == self.feature_num:
                return None
            elif self.hidden.size(1) == self.feature_num * 2:
                return self.fc_2(self.hidden)
            elif self.hidden.size(1) == self.feature_num * 3:
                return self.fc_3(self.hidden)
            elif self.hidden.size(1) == self.feature_num * 4:
                return self.fc_4(self.hidden)
            elif self.hidden.size(1) == self.feature_num * 5:
                return self.fc_5(self.hidden)
            else:
                raise ValueError('unexpected feature width: {}'.format(self.hidden.size()))

================================================
FILE: pycls/__init__.py
================================================


================================================
FILE: pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml
================================================
MODEL:
  TYPE: regnet
  NUM_CLASSES: 1000
REGNET:
  SE_ON: True
  DEPTH: 27
  W0: 48
  WA: 20.71
  WM: 2.65
  GROUP_W: 24
OPTIM:
  LR_POLICY: cos
  BASE_LR: 0.8
  MAX_EPOCH: 100
  MOMENTUM: 0.9
  WEIGHT_DECAY: 5e-5
  WARMUP_EPOCHS: 5
TRAIN:
  DATASET: imagenet
  IM_SIZE: 224
  BATCH_SIZE: 1024
TEST:
  DATASET: imagenet
  IM_SIZE: 256
  BATCH_SIZE: 800
NUM_GPUS: 1
OUT_DIR: .


================================================
FILE: pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml
================================================
MODEL:
  TYPE: regnet
  NUM_CLASSES: 1000
REGNET:
  SE_ON: True
  DEPTH: 15
  W0: 48
  WA: 32.54
  WM: 2.32
  GROUP_W: 16
OPTIM:
  LR_POLICY: cos
  BASE_LR: 0.8
  MAX_EPOCH: 100
  MOMENTUM: 0.9
  WEIGHT_DECAY: 5e-5
  WARMUP_EPOCHS: 5
TRAIN:
  DATASET: imagenet
  IM_SIZE: 224
  BATCH_SIZE: 1024
TEST:
  DATASET: imagenet
  IM_SIZE: 256
  BATCH_SIZE: 800
NUM_GPUS: 1
OUT_DIR: .


================================================
FILE: pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml
================================================
MODEL:
  TYPE: regnet
  NUM_CLASSES: 1000
REGNET:
  SE_ON: True
  DEPTH: 14
  W0: 56
  WA: 38.84
  WM: 2.4
  GROUP_W: 16
OPTIM:
  LR_POLICY: cos
  BASE_LR: 0.8
  MAX_EPOCH: 100
  MOMENTUM: 0.9
  WEIGHT_DECAY: 5e-5
  WARMUP_EPOCHS: 5
TRAIN:
  DATASET: imagenet
  IM_SIZE: 224
  BATCH_SIZE: 1024
TEST:
  DATASET: imagenet
  IM_SIZE: 256
  BATCH_SIZE: 800
NUM_GPUS: 1
OUT_DIR: .


================================================
FILE: pycls/core/__init__.py
================================================


================================================
FILE: pycls/core/config.py
================================================
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Configuration file (powered by YACS)."""

import os

from pycls.utils.io import cache_url
from yacs.config import CfgNode as CN


# Global config object
_C = CN()

# Example usage:
#   from core.config import cfg
cfg = _C


# ---------------------------------------------------------------------------- #
# Model options
# ---------------------------------------------------------------------------- #
_C.MODEL = CN()

# Model type
_C.MODEL.TYPE = ""

# Number of weight layers
_C.MODEL.DEPTH = 0

# Number of classes
_C.MODEL.NUM_CLASSES = 10

# Loss function (see pycls/models/loss.py for options)
_C.MODEL.LOSS_FUN = "cross_entropy"


# ---------------------------------------------------------------------------- #
# ResNet options
# ---------------------------------------------------------------------------- #
_C.RESNET = CN()

# Transformation function (see pycls/models/resnet.py for options)
_C.RESNET.TRANS_FUN = "basic_transform"

# Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
_C.RESNET.NUM_GROUPS = 1

# Width of each group (64 -> ResNet; 4 -> ResNeXt)
_C.RESNET.WIDTH_PER_GROUP = 64

# Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
_C.RESNET.STRIDE_1X1 = True


# ---------------------------------------------------------------------------- #
# AnyNet options
# ---------------------------------------------------------------------------- #
_C.ANYNET = CN()

# Stem type
_C.ANYNET.STEM_TYPE = "plain_block"

# Stem width
_C.ANYNET.STEM_W = 32

# Block type
_C.ANYNET.BLOCK_TYPE = "plain_block"

# Depth for each stage (number of blocks in the stage)
_C.ANYNET.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.ANYNET.WIDTHS = []

# Strides for each stage (applies to the first block of each stage)
_C.ANYNET.STRIDES = []

# Bottleneck multipliers for each stage (applies to bottleneck block)
_C.ANYNET.BOT_MULS = []

# Group widths for each stage (applies to bottleneck block)
_C.ANYNET.GROUP_WS = []

# Whether SE is enabled for res_bottleneck_block
_C.ANYNET.SE_ON = False

# SE ratio
_C.ANYNET.SE_R = 0.25

# ---------------------------------------------------------------------------- #
# RegNet options
# ---------------------------------------------------------------------------- #
_C.REGNET = CN()

# Stem type
_C.REGNET.STEM_TYPE = "simple_stem_in"
# Stem width
_C.REGNET.STEM_W = 32
# Block type
_C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
# Stride of each stage
_C.REGNET.STRIDE = 2
# Squeeze-and-Excitation (RegNetY)
_C.REGNET.SE_ON = False
_C.REGNET.SE_R = 0.25

# Depth
_C.REGNET.DEPTH = 10
# Initial width
_C.REGNET.W0 = 32
# Slope
_C.REGNET.WA = 5.0
# Quantization
_C.REGNET.WM = 2.5
# Group width
_C.REGNET.GROUP_W = 16
# Bottleneck multiplier (bm = 1 / b from the paper)
_C.REGNET.BOT_MUL = 1.0
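
# Illustration (not part of the original config.py): the DEPTH/W0/WA/WM
# parameters above generate per-block widths via the RegNet quantized linear
# rule u_j = W0 + WA*j, snapped to powers of WM. This pure-Python sketch
# follows the paper's parameterization; the repo's actual implementation
# lives in pycls/models/regnet.py, and the function name here is hypothetical.

```python
import math

def regnet_widths(w_a, w_0, w_m, depth, q=8):
    # continuous widths u_j = w_0 + w_a * j, quantized to powers of w_m,
    # then rounded to multiples of q channels
    widths = []
    for j in range(depth):
        u = w_0 + w_a * j
        s = round(math.log(u / w_0, w_m))
        widths.append(int(round(w_0 * (w_m ** s) / q)) * q)
    return widths
```

With the RegNetY-600MF settings above (WA=32.54, W0=48, WM=2.32, DEPTH=15) this produces 15 non-decreasing widths, all multiples of 8; blocks sharing a width form one stage.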


# ---------------------------------------------------------------------------- #
# EfficientNet options
# ---------------------------------------------------------------------------- #
_C.EN = CN()

# Stem width
_C.EN.STEM_W = 32

# Depth for each stage (number of blocks in the stage)
_C.EN.DEPTHS = []

# Width for each stage (width of each block in the stage)
_C.EN.WIDTHS = []

# Expansion ratios for MBConv blocks in each stage
_C.EN.EXP_RATIOS = []

# Squeeze-and-Excitation (SE) ratio
_C.EN.SE_R = 0.25

# Strides for each stage (applies to the first block of each stage)
_C.EN.STRIDES = []

# Kernel sizes for each stage
_C.EN.KERNELS = []

# Head width
_C.EN.HEAD_W = 1280

# Drop connect ratio
_C.EN.DC_RATIO = 0.0

# Dropout ratio
_C.EN.DROPOUT_RATIO = 0.0


# ---------------------------------------------------------------------------- #
# Batch norm options
# ---------------------------------------------------------------------------- #
_C.BN = CN()

# BN epsilon
_C.BN.EPS = 1e-5

# BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
_C.BN.MOM = 0.1

# Precise BN stats
_C.BN.USE_PRECISE_STATS = False
_C.BN.NUM_SAMPLES_PRECISE = 1024

# Initialize the gamma of the final BN of each block to zero
_C.BN.ZERO_INIT_FINAL_GAMMA = False

# Use a different weight decay for BN layers
_C.BN.USE_CUSTOM_WEIGHT_DECAY = False
_C.BN.CUSTOM_WEIGHT_DECAY = 0.0

# ---------------------------------------------------------------------------- #
# Optimizer options
# ---------------------------------------------------------------------------- #
_C.OPTIM = CN()

# Base learning rate
_C.OPTIM.BASE_LR = 0.1

# Learning rate policy, selected from {'cos', 'exp', 'steps'}
_C.OPTIM.LR_POLICY = "cos"
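
# Illustration (not part of the original config.py): the 'cos' policy denotes
# a half-cosine decay from BASE_LR to zero over MAX_EPOCH. This is a sketch
# consistent with pycls/utils/lr_policy.py (warmup over WARMUP_EPOCHS, not
# shown here, is applied on top); the function name is hypothetical.

```python
import math

def lr_cos(base_lr, cur_epoch, max_epoch):
    # half-period cosine: base_lr at epoch 0, decaying to 0 at max_epoch
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * cur_epoch / max_epoch))
```

For the YAML configs above (BASE_LR=0.8, MAX_EPOCH=100) this gives 0.8 at epoch 0, 0.4 at epoch 50, and 0 at epoch 100.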

# Exponential decay factor
_C.OPTIM.GAMMA = 0.1

# Steps for 'steps' policy (in epochs)
_C.OPTIM.STEPS = []

# Learning rate multiplier for 'steps' policy
_C.OPTIM.LR_MULT = 0.1

# Maximal number of epochs
_C.OPTIM.MAX_EPOCH = 200

# Momentum
_C.OPTIM.MOMENTUM = 0.9
    method as_sequential (line 86) | def as_sequential(self):
    method features (line 94) | def features(self, x):
    method forward (line 104) | def forward(self, x):
  function _create_model (line 112) | def _create_model(model_kwargs, variant, pretrained=False):
  function _gen_mobilenet_v3_rw (line 122) | def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=Fal...
  function _gen_mobilenet_v3 (line 170) | def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False,...
  function mobilenetv3_rw (line 266) | def mobilenetv3_rw(pretrained=False, **kwargs):
  function mobilenetv3_large_075 (line 278) | def mobilenetv3_large_075(pretrained=False, **kwargs):
  function mobilenetv3_large_100 (line 285) | def mobilenetv3_large_100(pretrained=False, **kwargs):
  function mobilenetv3_large_125 (line 292) | def mobilenetv3_large_125(pretrained=False, **kwargs):
  function mobilenetv3_large_minimal_100 (line 300) | def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
  function mobilenetv3_small_075 (line 307) | def mobilenetv3_small_075(pretrained=False, **kwargs):
  function mobilenetv3_small_100 (line 313) | def mobilenetv3_small_100(pretrained=False, **kwargs):
  function mobilenetv3_small_minimal_100 (line 319) | def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
  function tf_mobilenetv3_large_075 (line 325) | def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
  function tf_mobilenetv3_large_100 (line 333) | def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
  function tf_mobilenetv3_large_minimal_100 (line 341) | def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
  function tf_mobilenetv3_small_075 (line 349) | def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
  function tf_mobilenetv3_small_100 (line 357) | def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
  function tf_mobilenetv3_small_minimal_100 (line 365) | def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):

FILE: models/model_factory.py
  function create_model (line 6) | def create_model(

FILE: models/resnet.py
  function conv3x3 (line 17) | def conv3x3(in_planes, out_planes, stride=1, groups=1):
  function conv1x1 (line 23) | def conv1x1(in_planes, out_planes, stride=1):
  class BasicBlock (line 28) | class BasicBlock(nn.Module):
    method __init__ (line 31) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
    method forward (line 47) | def forward(self, x):
  class Bottleneck (line 66) | class Bottleneck(nn.Module):
    method __init__ (line 69) | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
    method forward (line 86) | def forward(self, x):
  class ResNet (line 109) | class ResNet(nn.Module):
    method __init__ (line 111) | def __init__(self, block, layers, num_classes=1000, zero_init_residual...
    method _make_layer (line 149) | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
    method forward (line 169) | def forward(self, x):
  function resnet18 (line 185) | def resnet18(pretrained=False, **kwargs):
  function resnet34 (line 197) | def resnet34(pretrained=False, **kwargs):
  function resnet50 (line 208) | def resnet50(pretrained=False, **kwargs):
  function resnet101 (line 220) | def resnet101(pretrained=False, **kwargs):
  function resnet152 (line 231) | def resnet152(pretrained=False, **kwargs):
  function resnext50_32x4d (line 242) | def resnext50_32x4d(pretrained=False, **kwargs):
  function resnext101_32x8d (line 249) | def resnext101_32x8d(pretrained=False, **kwargs):
  function count_parameters (line 256) | def count_parameters(model):
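models/resnet.py ends with a `count_parameters` helper. The usual implementation sums `p.numel()` over `model.parameters()`; the sketch below shows the same arithmetic over plain shape tuples so it runs without torch (the helper name is reused for illustration, not copied from the file).

```python
from math import prod

def count_parameters(shapes):
    # Sum of element counts over parameter shapes -- the arithmetic torch
    # code expresses as sum(p.numel() for p in model.parameters()).
    # Shape tuples instead of tensors keep the sketch torch-free.
    return sum(prod(s) for s in shapes)

# The 7x7 stem conv of an ImageNet ResNet: 64 filters over 3 input channels.
print(count_parameters([(64, 3, 7, 7)]))  # 9408
```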

FILE: network.py
  class Memory (line 8) | class Memory:
    method __init__ (line 9) | def __init__(self):
    method clear_memory (line 17) | def clear_memory(self):
  class ActorCritic (line 26) | class ActorCritic(nn.Module):
    method __init__ (line 27) | def __init__(self, feature_dim, state_dim, hidden_state_dim=1024, poli...
    method forward (line 64) | def forward(self):
    method act (line 67) | def act(self, state_ini, memory, restart_batch=False, training=False):
    method evaluate (line 100) | def evaluate(self, state, action):
  class PPO (line 131) | class PPO:
    method __init__ (line 132) | def __init__(self, feature_dim, state_dim, hidden_state_dim, policy_conv,
    method select_action (line 149) | def select_action(self, state, memory, restart_batch=False, training=T...
    method update (line 152) | def update(self, memory):
  class Full_layer (line 186) | class Full_layer(torch.nn.Module):
    method __init__ (line 187) | def __init__(self, feature_num, hidden_state_dim=1024, fc_rnn=True, cl...
    method forward (line 207) | def forward(self, x, restart=False):

FILE: pycls/core/config.py
  function assert_and_infer_cfg (line 357) | def assert_and_infer_cfg(cache_urls=True):
  function cache_cfg_urls (line 392) | def cache_cfg_urls():
  function dump_cfg (line 400) | def dump_cfg():
  function load_cfg (line 407) | def load_cfg(out_dir, cfg_dest="config.yaml"):

FILE: pycls/core/losses.py
  function get_loss_fun (line 18) | def get_loss_fun():
  function register_loss_fun (line 26) | def register_loss_fun(name, ctor):

FILE: pycls/core/model_builder.py
  function build_model (line 25) | def build_model():
  function register_model (line 50) | def register_model(name, ctor):

FILE: pycls/core/old_config.py
  function assert_and_infer_cfg (line 357) | def assert_and_infer_cfg(cache_urls=True):
  function cache_cfg_urls (line 392) | def cache_cfg_urls():
  function dump_cfg (line 400) | def dump_cfg():
  function load_cfg (line 407) | def load_cfg(out_dir, cfg_dest="config.yaml"):

FILE: pycls/core/optimizer.py
  function construct_optimizer (line 15) | def construct_optimizer(model):
  function get_epoch_lr (line 71) | def get_epoch_lr(cur_epoch):
  function set_lr (line 76) | def set_lr(optimizer, new_lr):

FILE: pycls/datasets/cifar10.py
  class Cifar10 (line 28) | class Cifar10(torch.utils.data.Dataset):
    method __init__ (line 31) | def __init__(self, data_path, split):
    method _load_batch (line 44) | def _load_batch(self, batch_path):
    method _load_data (line 49) | def _load_data(self):
    method _prepare_im (line 69) | def _prepare_im(self, im):
    method __getitem__ (line 77) | def __getitem__(self, index):
    method __len__ (line 82) | def __len__(self):

FILE: pycls/datasets/imagenet.py
  class ImageNet (line 35) | class ImageNet(torch.utils.data.Dataset):
    method __init__ (line 38) | def __init__(self, data_path, split):
    method _construct_imdb (line 49) | def _construct_imdb(self):
    method _prepare_im (line 72) | def _prepare_im(self, im):
    method __getitem__ (line 97) | def __getitem__(self, index):
    method __len__ (line 107) | def __len__(self):

FILE: pycls/datasets/loader.py
  function _construct_loader (line 23) | def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
  function construct_train_loader (line 50) | def construct_train_loader():
  function construct_test_loader (line 61) | def construct_test_loader():
  function shuffle (line 72) | def shuffle(loader, cur_epoch):

FILE: pycls/datasets/paths.py
  function has_data_path (line 23) | def has_data_path(dataset_name):
  function get_data_path (line 28) | def get_data_path(dataset_name):
  function register_path (line 33) | def register_path(name, path):
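pycls/datasets/paths.py reads like a small name-to-path registry. A minimal sketch of that pattern (the dict name and the example paths here are invented, not taken from the file):

```python
# Hypothetical registry mapping dataset names to filesystem paths.
_paths = {"cifar10": "/data/cifar10", "imagenet": "/data/imagenet"}

def has_data_path(dataset_name):
    return dataset_name in _paths

def get_data_path(dataset_name):
    assert has_data_path(dataset_name), "Unknown dataset: " + dataset_name
    return _paths[dataset_name]

def register_path(name, path):
    # Callers can point an extra dataset name at their own data directory.
    _paths[name] = path

register_path("cifar100", "/data/cifar100")
print(get_data_path("cifar100"))  # /data/cifar100
```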

FILE: pycls/datasets/transforms.py
  function color_norm (line 16) | def color_norm(im, mean, std):
  function zero_pad (line 24) | def zero_pad(im, pad_size):
  function horizontal_flip (line 30) | def horizontal_flip(im, p, order="CHW"):
  function random_crop (line 41) | def random_crop(im, size, pad_size=0):
  function scale (line 53) | def scale(size, im):
  function center_crop (line 67) | def center_crop(size, im):
  function random_sized_crop (line 77) | def random_sized_crop(im, size, area_frac=0.08, max_iter=10):
  function lighting (line 98) | def lighting(im, alpha_std, eig_val, eig_vec):
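`center_crop(size, im)` above reduces to simple offset arithmetic: center a `size x size` window in the image and slice. A sketch of just that arithmetic (the helper name is hypothetical; the real function applies these bounds to a CHW array):

```python
import math

def center_crop_box(size, h, w):
    # Top-left offset that centers a size x size window in an h x w image,
    # rounding up when the margin is odd.
    y = int(math.ceil((h - size) / 2))
    x = int(math.ceil((w - size) / 2))
    return y, x, y + size, x + size

# Standard ImageNet eval: 224 crop from a 256-resized image.
print(center_crop_box(224, 256, 256))  # (16, 16, 240, 240)
```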

FILE: pycls/models/anynet.py
  function get_stem_fun (line 19) | def get_stem_fun(stem_type):
  function get_block_fun (line 32) | def get_block_fun(block_type):
  class AnyHead (line 45) | class AnyHead(nn.Module):
    method __init__ (line 48) | def __init__(self, w_in, nc):
    method forward (line 53) | def forward(self, x):
  class VanillaBlock (line 61) | class VanillaBlock(nn.Module):
    method __init__ (line 64) | def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
    method _construct (line 71) | def _construct(self, w_in, w_out, stride):
    method forward (line 83) | def forward(self, x):
  class BasicTransform (line 89) | class BasicTransform(nn.Module):
    method __init__ (line 92) | def __init__(self, w_in, w_out, stride):
    method _construct (line 96) | def _construct(self, w_in, w_out, stride):
    method forward (line 108) | def forward(self, x):
  class ResBasicBlock (line 114) | class ResBasicBlock(nn.Module):
    method __init__ (line 117) | def __init__(self, w_in, w_out, stride, bm=None, gw=None, se_r=None):
    method _add_skip_proj (line 124) | def _add_skip_proj(self, w_in, w_out, stride):
    method _construct (line 130) | def _construct(self, w_in, w_out, stride):
    method forward (line 138) | def forward(self, x):
  class SE (line 147) | class SE(nn.Module):
    method __init__ (line 150) | def __init__(self, w_in, w_se):
    method _construct (line 154) | def _construct(self, w_in, w_se):
    method forward (line 165) | def forward(self, x):
  class BottleneckTransform (line 169) | class BottleneckTransform(nn.Module):
    method __init__ (line 172) | def __init__(self, w_in, w_out, stride, bm, gw, se_r):
    method _construct (line 176) | def _construct(self, w_in, w_out, stride, bm, gw, se_r):
    method forward (line 200) | def forward(self, x):
  class ResBottleneckBlock (line 206) | class ResBottleneckBlock(nn.Module):
    method __init__ (line 209) | def __init__(self, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
    method _add_skip_proj (line 213) | def _add_skip_proj(self, w_in, w_out, stride):
    method _construct (line 219) | def _construct(self, w_in, w_out, stride, bm, gw, se_r):
    method forward (line 227) | def forward(self, x):
  class ResStemCifar (line 236) | class ResStemCifar(nn.Module):
    method __init__ (line 239) | def __init__(self, w_in, w_out):
    method _construct (line 243) | def _construct(self, w_in, w_out):
    method forward (line 251) | def forward(self, x):
  class ResStemIN (line 257) | class ResStemIN(nn.Module):
    method __init__ (line 260) | def __init__(self, w_in, w_out):
    method _construct (line 264) | def _construct(self, w_in, w_out):
    method forward (line 273) | def forward(self, x):
  class SimpleStemIN (line 279) | class SimpleStemIN(nn.Module):
    method __init__ (line 282) | def __init__(self, in_w, out_w):
    method _construct (line 286) | def _construct(self, in_w, out_w):
    method forward (line 294) | def forward(self, x):
  class AnyStage (line 300) | class AnyStage(nn.Module):
    method __init__ (line 303) | def __init__(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
    method _construct (line 307) | def _construct(self, w_in, w_out, stride, d, block_fun, bm, gw, se_r):
    method forward (line 318) | def forward(self, x):
  class AnyNet (line 324) | class AnyNet(nn.Module):
    method __init__ (line 327) | def __init__(self, **kwargs):
    method _construct (line 357) | def _construct(self, stem_type, stem_w, block_type, ds, ws, ss, bms, g...
    method forward (line 380) | def forward(self, x):

FILE: pycls/models/effnet.py
  class EffHead (line 20) | class EffHead(nn.Module):
    method __init__ (line 23) | def __init__(self, w_in, w_out, nc):
    method _construct (line 27) | def _construct(self, w_in, w_out, nc):
    method forward (line 42) | def forward(self, x):
  class Swish (line 51) | class Swish(nn.Module):
    method __init__ (line 54) | def __init__(self):
    method forward (line 57) | def forward(self, x):
  class SE (line 61) | class SE(nn.Module):
    method __init__ (line 64) | def __init__(self, w_in, w_se):
    method _construct (line 68) | def _construct(self, w_in, w_se):
    method forward (line 79) | def forward(self, x):
  class MBConv (line 83) | class MBConv(nn.Module):
    method __init__ (line 86) | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
    method _construct (line 90) | def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out):
    method forward (line 126) | def forward(self, x):
  class EffStage (line 146) | class EffStage(nn.Module):
    method __init__ (line 149) | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
    method _construct (line 153) | def _construct(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
    method forward (line 165) | def forward(self, x):
  class StemIN (line 171) | class StemIN(nn.Module):
    method __init__ (line 174) | def __init__(self, w_in, w_out):
    method _construct (line 178) | def _construct(self, w_in, w_out):
    method forward (line 186) | def forward(self, x):
  class EffNet (line 192) | class EffNet(nn.Module):
    method __init__ (line 195) | def __init__(self):
    method _construct (line 216) | def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
    method forward (line 232) | def forward(self, x):

FILE: pycls/models/regnet.py
  function quantize_float (line 19) | def quantize_float(f, q):
  function adjust_ws_gs_comp (line 24) | def adjust_ws_gs_comp(ws, bms, gs):
  function get_stages_from_blocks (line 33) | def get_stages_from_blocks(ws, rs):
  function generate_regnet (line 42) | def generate_regnet(w_a, w_0, w_m, d, q=8):
  class RegNet (line 54) | class RegNet(AnyNet):
    method __init__ (line 57) | def __init__(self):
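pycls/models/regnet.py derives all stage widths from four parameters (`w_a`, `w_0`, `w_m`, `d`). The published pycls recipe snaps a linear width ramp onto a `w_0 * w_m**k` log grid and quantizes to multiples of `q`; the sketch below follows that recipe in pure Python (verify details against the file itself).

```python
from math import log

def quantize_float(f, q):
    # Round to the nearest multiple of q.
    return int(round(f / q) * q)

def generate_regnet(w_a, w_0, w_m, d, q=8):
    # Continuous linear ramp of per-block widths.
    ws_cont = [w_0 + w_a * i for i in range(d)]
    # Snap each width onto the w_0 * w_m**k log grid.
    ks = [round(log(w / w_0) / log(w_m)) for w in ws_cont]
    # Quantize to multiples of q so channel counts stay divisible.
    ws = [quantize_float(w_0 * w_m ** k, q) for k in ks]
    # Blocks sharing a width form one stage.
    return ws, len(set(ws))

print(generate_regnet(w_a=8, w_0=48, w_m=2, d=4))  # ([48, 48, 48, 96], 2)
```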

FILE: pycls/models/resnet.py
  function get_trans_fun (line 23) | def get_trans_fun(name):
  class ResHead (line 35) | class ResHead(nn.Module):
    method __init__ (line 38) | def __init__(self, w_in, nc):
    method forward (line 43) | def forward(self, x):
  class BasicTransform (line 50) | class BasicTransform(nn.Module):
    method __init__ (line 53) | def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
    method _construct (line 60) | def _construct(self, w_in, w_out, stride):
    method forward (line 72) | def forward(self, x):
  class BottleneckTransform (line 78) | class BottleneckTransform(nn.Module):
    method __init__ (line 81) | def __init__(self, w_in, w_out, stride, w_b, num_gs):
    method _construct (line 85) | def _construct(self, w_in, w_out, stride, w_b, num_gs):
    method forward (line 105) | def forward(self, x):
  class ResBlock (line 111) | class ResBlock(nn.Module):
    method __init__ (line 114) | def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
    method _add_skip_proj (line 118) | def _add_skip_proj(self, w_in, w_out, stride):
    method _construct (line 124) | def _construct(self, w_in, w_out, stride, trans_fun, w_b, num_gs):
    method forward (line 132) | def forward(self, x):
  class ResStage (line 141) | class ResStage(nn.Module):
    method __init__ (line 144) | def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
    method _construct (line 148) | def _construct(self, w_in, w_out, stride, d, w_b, num_gs):
    method forward (line 160) | def forward(self, x):
  class ResStem (line 166) | class ResStem(nn.Module):
    method __init__ (line 169) | def __init__(self, w_in, w_out):
    method _construct_cifar (line 179) | def _construct_cifar(self, w_in, w_out):
    method _construct_imagenet (line 187) | def _construct_imagenet(self, w_in, w_out):
    method forward (line 196) | def forward(self, x):
  class ResNet (line 202) | class ResNet(nn.Module):
    method __init__ (line 205) | def __init__(self):
    method _construct_cifar (line 221) | def _construct_cifar(self):
    method _construct_imagenet (line 239) | def _construct_imagenet(self):
    method forward (line 272) | def forward(self, x):

FILE: pycls/utils/benchmark.py
  function compute_fw_test_time (line 17) | def compute_fw_test_time(model, inputs):
  function compute_fw_bw_time (line 38) | def compute_fw_bw_time(model, loss_fun, inputs, labels):
  function compute_precise_time (line 69) | def compute_precise_time(model, loss_fun):

FILE: pycls/utils/checkpoint.py
  function get_checkpoint_dir (line 23) | def get_checkpoint_dir():
  function get_checkpoint (line 28) | def get_checkpoint(epoch):
  function get_last_checkpoint (line 34) | def get_last_checkpoint():
  function has_checkpoint (line 43) | def has_checkpoint():
  function is_checkpoint_epoch (line 51) | def is_checkpoint_epoch(cur_epoch):
  function save_checkpoint (line 56) | def save_checkpoint(model, optimizer, epoch):
  function load_checkpoint (line 78) | def load_checkpoint(checkpoint_file, model, optimizer=None):

FILE: pycls/utils/distributed.py
  function is_master_proc (line 14) | def is_master_proc():
  function init_process_group (line 25) | def init_process_group(proc_rank, world_size):
  function destroy_process_group (line 38) | def destroy_process_group():
  function scaled_all_reduce (line 43) | def scaled_all_reduce(tensors):

FILE: pycls/utils/error_handler.py
  class ChildException (line 15) | class ChildException(Exception):
    method __init__ (line 18) | def __init__(self, child_trace):
  class ErrorHandler (line 22) | class ErrorHandler(object):
    method __init__ (line 29) | def __init__(self, error_queue):
    method add_child (line 40) | def add_child(self, pid):
    method listen (line 44) | def listen(self):
    method signal_handler (line 53) | def signal_handler(self, _sig_num, _stack_frame):

FILE: pycls/utils/io.py
  function cache_url (line 22) | def cache_url(url_or_file, cache_dir):
  function _progress_bar (line 50) | def _progress_bar(count, total):
  function download_url (line 69) | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_pro...

FILE: pycls/utils/logging.py
  function _suppress_print (line 31) | def _suppress_print():
  function setup_logging (line 40) | def setup_logging():
  function get_logger (line 60) | def get_logger(name):
  function log_json_stats (line 65) | def log_json_stats(stats):
  function load_json_stats (line 77) | def load_json_stats(log_file):
  function parse_json_stats (line 86) | def parse_json_stats(log, row_type, key):
  function get_log_files (line 94) | def get_log_files(log_dir, name_filter=""):

FILE: pycls/utils/lr_policy.py
  function lr_fun_steps (line 14) | def lr_fun_steps(cur_epoch):
  function lr_fun_exp (line 20) | def lr_fun_exp(cur_epoch):
  function lr_fun_cos (line 25) | def lr_fun_cos(cur_epoch):
  function get_lr_fun (line 31) | def get_lr_fun():
  function get_epoch_lr (line 39) | def get_epoch_lr(cur_epoch):
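The names in pycls/utils/lr_policy.py suggest step, exponential, and half-cosine schedules. Sketches of all three follow; the constants stand in for the cfg values (base LR, milestones, decay, max epoch) the real file reads, and the exact pycls formulas should be checked against the file.

```python
import math

# Placeholder hyperparameters standing in for the cfg values the file reads.
BASE_LR, MAX_EPOCH = 0.1, 100
STEPS, LR_MULT = [30, 60, 90], 0.1
GAMMA = 0.98

def lr_fun_steps(cur_epoch):
    # Multiply by LR_MULT at each milestone already passed.
    ind = sum(1 for s in STEPS if cur_epoch >= s)
    return BASE_LR * LR_MULT ** ind

def lr_fun_exp(cur_epoch):
    # Smooth exponential decay.
    return BASE_LR * GAMMA ** cur_epoch

def lr_fun_cos(cur_epoch):
    # Half-cosine from BASE_LR at epoch 0 down to ~0 at MAX_EPOCH.
    return 0.5 * BASE_LR * (1.0 + math.cos(math.pi * cur_epoch / MAX_EPOCH))

print(lr_fun_steps(45), round(lr_fun_cos(50), 6))
```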

FILE: pycls/utils/meters.py
  function eta_str (line 20) | def eta_str(eta_td):
  class ScalarMeter (line 28) | class ScalarMeter(object):
    method __init__ (line 31) | def __init__(self, window_size):
    method reset (line 36) | def reset(self):
    method add_value (line 41) | def add_value(self, value):
    method get_win_median (line 46) | def get_win_median(self):
    method get_win_avg (line 49) | def get_win_avg(self):
    method get_global_avg (line 52) | def get_global_avg(self):
  class TrainMeter (line 56) | class TrainMeter(object):
    method __init__ (line 59) | def __init__(self, epoch_iters):
    method reset (line 74) | def reset(self, timer=False):
    method iter_tic (line 86) | def iter_tic(self):
    method iter_toc (line 89) | def iter_toc(self):
    method update_stats (line 92) | def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
    method get_iter_stats (line 104) | def get_iter_stats(self, cur_epoch, cur_iter):
    method log_iter_stats (line 125) | def log_iter_stats(self, cur_epoch, cur_iter):
    method get_epoch_stats (line 131) | def get_epoch_stats(self, cur_epoch):
    method log_epoch_stats (line 153) | def log_epoch_stats(self, cur_epoch):
  class TestMeter (line 158) | class TestMeter(object):
    method __init__ (line 161) | def __init__(self, max_iter):
    method reset (line 175) | def reset(self, min_errs=False):
    method iter_tic (line 186) | def iter_tic(self):
    method iter_toc (line 189) | def iter_toc(self):
    method update_stats (line 192) | def update_stats(self, top1_err, top5_err, mb_size):
    method get_iter_stats (line 199) | def get_iter_stats(self, cur_epoch, cur_iter):
    method log_iter_stats (line 213) | def log_iter_stats(self, cur_epoch, cur_iter):
    method get_epoch_stats (line 219) | def get_epoch_stats(self, cur_epoch):
    method log_epoch_stats (line 237) | def log_epoch_stats(self, cur_epoch):
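`ScalarMeter` above exposes both windowed statistics (for smoothed per-iteration logging) and a global average. A stdlib sketch of that interface, with `collections.deque(maxlen=...)` doing the window bookkeeping (field names are guesses at the original's intent):

```python
from collections import deque
from statistics import median

class ScalarMeter:
    def __init__(self, window_size):
        self.deque = deque(maxlen=window_size)  # recent values only
        self.total = 0.0                        # running sum over all values
        self.count = 0

    def reset(self):
        self.deque.clear()
        self.total = 0.0
        self.count = 0

    def add_value(self, value):
        self.deque.append(value)  # oldest value falls off automatically
        self.total += value
        self.count += 1

    def get_win_median(self):
        return median(self.deque)

    def get_win_avg(self):
        return sum(self.deque) / len(self.deque)

    def get_global_avg(self):
        return self.total / self.count

m = ScalarMeter(window_size=3)
for v in [1.0, 2.0, 3.0, 4.0]:
    m.add_value(v)
print(m.get_win_median(), m.get_global_avg())  # 3.0 2.5
```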

FILE: pycls/utils/metrics.py
  function topks_correct (line 20) | def topks_correct(preds, labels, ks):
  function topk_errors (line 40) | def topk_errors(preds, labels, ks):
  function topk_accuracies (line 46) | def topk_accuracies(preds, labels, ks):
  function params_count (line 52) | def params_count(model):
  function flops_count (line 57) | def flops_count(model):
  function acts_count (line 79) | def acts_count(model):
  function gpu_mem_usage (line 101) | def gpu_mem_usage():

FILE: pycls/utils/multiprocessing.py
  function run (line 17) | def run(proc_rank, world_size, error_queue, fun, fun_args, fun_kwargs):
  function multi_proc_run (line 35) | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):

FILE: pycls/utils/net.py
  function init_weights (line 18) | def init_weights(m):
  function compute_precise_bn_stats (line 36) | def compute_precise_bn_stats(model, loader):
  function reset_bn_stats (line 65) | def reset_bn_stats(model):
  function drop_connect (line 72) | def drop_connect(x, drop_ratio):
  function get_flat_weights (line 82) | def get_flat_weights(model):
  function set_flat_weights (line 87) | def set_flat_weights(model, flat_weights):

FILE: pycls/utils/plotting.py
  function get_plot_colors (line 17) | def get_plot_colors(max_colors, color_format="pyplot"):
  function prepare_plot_data (line 27) | def prepare_plot_data(log_files, names, key="top1_err"):
  function plot_error_curves_plotly (line 42) | def plot_error_curves_plotly(log_files, names, filename, key="top1_err"):
  function plot_error_curves_pyplot (line 115) | def plot_error_curves_pyplot(log_files, names, filename=None, key="top1_...

FILE: pycls/utils/timer.py
  class Timer (line 13) | class Timer(object):
    method __init__ (line 16) | def __init__(self):
    method tic (line 19) | def tic(self):
    method toc (line 24) | def toc(self):
    method reset (line 30) | def reset(self):
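The tic/toc `Timer` above is a classic accumulating stopwatch: `toc` adds the elapsed interval to a running total so per-iteration times can be averaged. A sketch under that assumption (attribute names are guesses at the original's):

```python
import time

class Timer:
    def __init__(self):
        self.reset()

    def tic(self):
        # Mark the start of a timed interval.
        self.start_time = time.perf_counter()

    def toc(self):
        # Close the interval and fold it into the running average.
        self.diff = time.perf_counter() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls

    def reset(self):
        self.start_time = 0.0
        self.diff = 0.0
        self.total_time = 0.0
        self.calls = 0
        self.average_time = 0.0

t = Timer()
t.tic()
time.sleep(0.01)
t.toc()
print(t.average_time > 0.0)
```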

FILE: simplejson/__init__.py
  function _import_OrderedDict (line 136) | def _import_OrderedDict():
  function _import_c_make_encoder (line 145) | def _import_c_make_encoder():
  function dump (line 172) | def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True,
  function dumps (line 294) | def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True,
  function load (line 419) | def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None,
  function loads (line 474) | def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None,
  function _toggle_speedups (line 545) | def _toggle_speedups(enabled):
  function simple_first (line 580) | def simple_first(kv):

FILE: simplejson/_speedups.c  (several static functions below were extracted by their return-type line; the C name sits on the following source line)
  function json_PyOS_string_to_double (line 33) | static double
  type JSON_Accu (line 86) | typedef struct {
  type PyScannerObject (line 112) | typedef struct _PyScannerObject {
  type PyEncoderObject (line 136) | typedef struct _PyEncoderObject {
  function is_raw_json (line 261) | static int
  function JSON_Accu_Init (line 267) | static int
  function flush_accumulator (line 278) | static int
  function JSON_Accu_Accumulate (line 308) | static int
  function (line 333) | static PyObject *
  function JSON_Accu_Destroy (line 352) | static void
  function IS_DIGIT (line 359) | static int
  function (line 365) | static PyObject *
  function _is_namedtuple (line 385) | static int
  function _has_for_json_hook (line 399) | static int
  function _convertPyInt_AsSsize_t (line 413) | static int
  function (line 423) | static PyObject *
  function (line 430) | static Py_ssize_t
  function (line 474) | static Py_ssize_t
  function (line 499) | static PyObject *
  function (line 542) | static PyObject *
  function (line 556) | static PyObject *
  function (line 605) | static PyObject *
  function (line 666) | static PyObject *
  function raise_errmsg (line 765) | static void
  function (line 775) | static PyObject *
  function (line 785) | static PyObject *
  function (line 799) | static PyObject *
  function (line 843) | static PyObject *
  function (line 1050) | static PyObject *
  function (line 1247) | static PyObject *
  function (line 1289) | static PyObject *
  function scanner_dealloc (line 1310) | static void
  function scanner_traverse (line 1319) | static int
  function scanner_clear (line 1336) | static int
  function (line 1354) | static PyObject *
  function (line 1515) | static PyObject *
  function (line 1679) | static PyObject *
  function (line 1759) | static PyObject *
  function (line 1839) | static PyObject *
  function (line 1861) | static PyObject *
  function (line 1966) | static PyObject *
  function (line 2078) | static PyObject *
  function (line 2185) | static PyObject *
  function (line 2317) | static PyObject *
  function (line 2352) | static PyObject *
  function (line 2377) | static PyObject *
  function (line 2482) | static PyObject *
  function (line 2658) | static PyObject *
  function (line 2681) | static PyObject *
  function (line 2715) | static PyObject *
  function (line 2759) | static PyObject *
  function _steal_accumulate (line 2784) | static int
  function encoder_listencode_obj (line 2793) | static int
  function encoder_listencode_dict (line 2941) | static int
  function encoder_listencode_list (line 3080) | static int
  function encoder_dealloc (line 3174) | static void
  function encoder_traverse (line 3183) | static int
  function encoder_clear (line 3206) | static int
  type PyModuleDef (line 3292) | struct PyModuleDef
  function (line 3305) | PyObject *
  function init_constants (line 3317) | static int
  function (line 3343) | static PyObject *
  function (line 3373) | PyMODINIT_FUNC
  function init_speedups (line 3379) | void

FILE: simplejson/compat.py
  function b (line 6) | def b(s):
  function b (line 25) | def b(s):

FILE: simplejson/decoder.py
  function _import_c_scanstring (line 10) | def _import_c_scanstring():
  function _floatconstants (line 24) | def _floatconstants():
  function py_scanstring (line 49) | def py_scanstring(s, end, encoding=None, strict=True,
  function JSONObject (line 142) | def JSONObject(state, encoding, strict, scan_once, object_hook,
  function JSONArray (line 236) | def JSONArray(state, scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR):
  class JSONDecoder (line 272) | class JSONDecoder(object):
    method __init__ (line 302) | def __init__(self, encoding=None, object_hook=None, parse_float=None,
    method decode (line 363) | def decode(self, s, _w=WHITESPACE.match, _PY3=PY3):
    method raw_decode (line 376) | def raw_decode(self, s, idx=0, _w=WHITESPACE.match, _PY3=PY3):

FILE: simplejson/encoder.py
  function _import_speedups (line 9) | def _import_speedups():
  function encode_basestring (line 38) | def encode_basestring(s, _PY3=PY3, _q=u'"'):
  function py_encode_basestring_ascii (line 65) | def py_encode_basestring_ascii(s, _PY3=PY3):
  class JSONEncoder (line 109) | class JSONEncoder(object):
    method __init__ (line 141) | def __init__(self, skipkeys=False, ensure_ascii=True,
    method default (line 254) | def default(self, o):
    method encode (line 275) | def encode(self, o):
    method iterencode (line 304) | def iterencode(self, o, _one_shot=False):
  class JSONEncoderForHTML (line 383) | class JSONEncoderForHTML(JSONEncoder):
    method encode (line 397) | def encode(self, o):
    method iterencode (line 406) | def iterencode(self, o, _one_shot=False):
  function _make_iterencode (line 420) | def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,

FILE: simplejson/errors.py
  function linecol (line 6) | def linecol(doc, pos):
  function errmsg (line 15) | def errmsg(msg, doc, pos, end=None):
  class JSONDecodeError (line 26) | class JSONDecodeError(ValueError):
    method __init__ (line 40) | def __init__(self, msg, doc, pos, end=None):
    method __reduce__ (line 52) | def __reduce__(self):
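`linecol(doc, pos)` converts a character offset into the 1-based line/column pair that `JSONDecodeError` messages report. The sketch below follows the standard newline-counting approach (the real implementation also special-cases non-string docs, omitted here):

```python
def linecol(doc, pos):
    # 1-based line number: newlines seen before pos, plus one.
    lineno = doc.count('\n', 0, pos) + 1
    if lineno == 1:
        colno = pos + 1
    else:
        # Column is the distance past the last newline before pos.
        colno = pos - doc.rindex('\n', 0, pos)
    return lineno, colno

print(linecol('{"a":\n 1x}', 8))  # (2, 3)
```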

FILE: simplejson/ordered_dict.py
  class OrderedDict (line 8) | class OrderedDict(dict, DictMixin):
    method __init__ (line 10) | def __init__(self, *args, **kwds):
    method clear (line 19) | def clear(self):
    method __setitem__ (line 25) | def __setitem__(self, key, value):
    method __delitem__ (line 32) | def __delitem__(self, key):
    method __iter__ (line 38) | def __iter__(self):
    method __reversed__ (line 45) | def __reversed__(self):
    method popitem (line 52) | def popitem(self, last=True):
    method __reduce__ (line 59) | def __reduce__(self):
    method keys (line 69) | def keys(self):
    method __repr__ (line 81) | def __repr__(self):
    method copy (line 86) | def copy(self):
    method fromkeys (line 90) | def fromkeys(cls, iterable, value=None):
    method __eq__ (line 96) | def __eq__(self, other):
    method __ne__ (line 102) | def __ne__(self, other):

FILE: simplejson/raw_json.py
  class RawJSON (line 4) | class RawJSON(object):
    method __init__ (line 8) | def __init__(self, encoded_json):

FILE: simplejson/scanner.py
  function _import_c_make_scanner (line 5) | def _import_c_make_scanner():
  function py_make_scanner (line 20) | def py_make_scanner(context):

FILE: simplejson/tests/__init__.py
  class NoExtensionTestSuite (line 7) | class NoExtensionTestSuite(unittest.TestSuite):
    method run (line 8) | def run(self, result):
  class TestMissingSpeedups (line 16) | class TestMissingSpeedups(unittest.TestCase):
    method runTest (line 17) | def runTest(self):
  function additional_tests (line 24) | def additional_tests(suite=None):
  function all_tests_suite (line 43) | def all_tests_suite():
  function main (line 64) | def main():

FILE: simplejson/tests/test_bigint_as_string.py
  class TestBigintAsString (line 6) | class TestBigintAsString(TestCase):
    method test_ints (line 23) | def test_ints(self):
    method test_lists (line 33) | def test_lists(self):
    method test_dicts (line 45) | def test_dicts(self):
    method test_dict_keys (line 57) | def test_dict_keys(self):

FILE: simplejson/tests/test_bitsize_int_as_string.py
  class TestBitSizeIntAsString (line 6) | class TestBitSizeIntAsString(TestCase):
    method test_invalid_counts (line 20) | def test_invalid_counts(self):
    method test_ints_outside_range_fails (line 26) | def test_ints_outside_range_fails(self):
    method test_ints (line 32) | def test_ints(self):
    method test_lists (line 42) | def test_lists(self):
    method test_dicts (line 53) | def test_dicts(self):
    method test_dict_keys (line 64) | def test_dict_keys(self):

FILE: simplejson/tests/test_check_circular.py
  function default_iterable (line 4) | def default_iterable(obj):
  class TestCheckCircular (line 7) | class TestCheckCircular(TestCase):
    method test_circular_dict (line 8) | def test_circular_dict(self):
    method test_circular_list (line 13) | def test_circular_list(self):
    method test_circular_composite (line 18) | def test_circular_composite(self):
    method test_circular_default (line 24) | def test_circular_default(self):
    method test_circular_off_default (line 28) | def test_circular_off_default(self):

FILE: simplejson/tests/test_decimal.py
  class TestDecimal (line 8) | class TestDecimal(TestCase):
    method dumps (line 10) | def dumps(self, obj, **kw):
    method loads (line 17) | def loads(self, s, **kw):
    method test_decimal_encode (line 23) | def test_decimal_encode(self):
    method test_decimal_decode (line 27) | def test_decimal_decode(self):
    method test_stringify_key (line 31) | def test_stringify_key(self):
    method test_decimal_roundtrip (line 39) | def test_decimal_roundtrip(self):
    method test_decimal_defaults (line 49) | def test_decimal_defaults(self):
    method test_decimal_reload (line 64) | def test_decimal_reload(self):

FILE: simplejson/tests/test_decode.py
  class MisbehavingBytesSubtype (line 9) | class MisbehavingBytesSubtype(binary_type):
    method decode (line 10) | def decode(self, encoding=None):
    method __str__ (line 12) | def __str__(self):
    method __bytes__ (line 14) | def __bytes__(self):
  class TestDecode (line 17) | class TestDecode(TestCase):
    method assertIs (line 19) | def assertIs(self, a, b):
    method test_decimal (line 22) | def test_decimal(self):
    method test_float (line 27) | def test_float(self):
    method test_decoder_optimizations (line 32) | def test_decoder_optimizations(self):
    method test_empty_objects (line 39) | def test_empty_objects(self):
    method test_object_pairs_hook (line 47) | def test_object_pairs_hook(self):
    method check_keys_reuse (line 64) | def check_keys_reuse(self, source, loads):
    method test_keys_reuse_str (line 70) | def test_keys_reuse_str(self):
    method test_keys_reuse_unicode (line 74) | def test_keys_reuse_unicode(self):
    method test_empty_strings (line 78) | def test_empty_strings(self):
    method test_raw_decode (line 84) | def test_raw_decode(self):
    method test_bytes_decode (line 98) | def test_bytes_decode(self):
    method test_bounds_checking (line 110) | def test_bounds_checking(self):

FILE: simplejson/tests/test_default.py
  class TestDefault (line 5) | class TestDefault(TestCase):
    method test_default (line 6) | def test_default(self):

FILE: simplejson/tests/test_dump.py
  class MisbehavingTextSubtype (line 5) | class MisbehavingTextSubtype(text_type):
    method __str__ (line 6) | def __str__(self):
  class MisbehavingBytesSubtype (line 9) | class MisbehavingBytesSubtype(binary_type):
    method decode (line 10) | def decode(self, encoding=None):
    method __str__ (line 12) | def __str__(self):
    method __bytes__ (line 14) | def __bytes__(self):
  function as_text_type (line 17) | def as_text_type(s):
  function decode_iso_8859_15 (line 22) | def decode_iso_8859_15(b):
  class TestDump (line 25) | class TestDump(TestCase):
    method test_dump (line 26) | def test_dump(self):
    method test_constants (line 31) | def test_constants(self):
    method test_stringify_key (line 37) | def test_stringify_key(self):
    method test_dumps (line 68) | def test_dumps(self):
    method test_encode_truefalse (line 71) | def test_encode_truefalse(self):
    method test_ordered_dict (line 85) | def test_ordered_dict(self):
    method test_indent_unknown_type_acceptance (line 93) | def test_indent_unknown_type_acceptance(self):
    method test_accumulator (line 133) | def test_accumulator(self):
    method test_sort_keys (line 138) | def test_sort_keys(self):
    method test_misbehaving_text_subtype (line 147) | def test_misbehaving_text_subtype(self):
    method test_misbehaving_bytes_subtype (line 163) | def test_misbehaving_bytes_subtype(self):
    method test_bytes_toplevel (line 178) | def test_bytes_toplevel(self):
    method test_bytes_nested (line 201) | def test_bytes_nested(self):
    method test_bytes_key (line 224) | def test_bytes_key(self):

FILE: simplejson/tests/test_encode_basestring_ascii.py
  class TestEncodeBaseStringAscii (line 25) | class TestEncodeBaseStringAscii(TestCase):
    method test_py_encode_basestring_ascii (line 26) | def test_py_encode_basestring_ascii(self):
    method test_c_encode_basestring_ascii (line 29) | def test_c_encode_basestring_ascii(self):
    method _test_encode_basestring_ascii (line 34) | def _test_encode_basestring_ascii(self, encode_basestring_ascii):
    method test_sorted_dict (line 44) | def test_sorted_dict(self):

FILE: simplejson/tests/test_encode_for_html.py
  class TestEncodeForHTML (line 5) | class TestEncodeForHTML(unittest.TestCase):
    method setUp (line 7) | def setUp(self):
    method test_basic_encode (line 12) | def test_basic_encode(self):
    method test_non_ascii_basic_encode (line 18) | def test_non_ascii_basic_encode(self):
    method test_basic_roundtrip (line 24) | def test_basic_roundtrip(self):
    method test_prevent_script_breakout (line 30) | def test_prevent_script_breakout(self):

FILE: simplejson/tests/test_errors.py
  class TestErrors (line 7) | class TestErrors(TestCase):
    method test_string_keys_error (line 8) | def test_string_keys_error(self):
    method test_not_serializable (line 19) | def test_not_serializable(self):
    method test_decode_error (line 29) | def test_decode_error(self):
    method test_scan_error (line 42) | def test_scan_error(self):
    method test_error_is_pickable (line 54) | def test_error_is_pickable(self):

FILE: simplejson/tests/test_fail.py
  class TestFail (line 108) | class TestFail(TestCase):
    method test_failures (line 109) | def test_failures(self):
    method test_array_decoder_issue46 (line 122) | def test_array_decoder_issue46(self):
    method test_truncated_input (line 138) | def test_truncated_input(self):

FILE: simplejson/tests/test_float.py
  class TestFloat (line 7) | class TestFloat(TestCase):
    method test_degenerates_allow (line 8) | def test_degenerates_allow(self):
    method test_degenerates_ignore (line 15) | def test_degenerates_ignore(self):
    method test_degenerates_deny (line 19) | def test_degenerates_deny(self):
    method test_floats (line 23) | def test_floats(self):
    method test_ints (line 30) | def test_ints(self):

FILE: simplejson/tests/test_for_json.py
  class ForJson (line 5) | class ForJson(object):
    method for_json (line 6) | def for_json(self):
  class NestedForJson (line 10) | class NestedForJson(object):
    method for_json (line 11) | def for_json(self):
  class ForJsonList (line 15) | class ForJsonList(object):
    method for_json (line 16) | def for_json(self):
  class DictForJson (line 20) | class DictForJson(dict):
    method for_json (line 21) | def for_json(self):
  class ListForJson (line 25) | class ListForJson(list):
    method for_json (line 26) | def for_json(self):
  class TestForJson (line 30) | class TestForJson(unittest.TestCase):
    method assertRoundTrip (line 31) | def assertRoundTrip(self, obj, other, for_json=True):
    method test_for_json_encodes_stand_alone_object (line 41) | def test_for_json_encodes_stand_alone_object(self):
    method test_for_json_encodes_object_nested_in_dict (line 46) | def test_for_json_encodes_object_nested_in_dict(self):
    method test_for_json_encodes_object_nested_in_list_within_dict (line 51) | def test_for_json_encodes_object_nested_in_list_within_dict(self):
    method test_for_json_encodes_object_nested_within_object (line 56) | def test_for_json_encodes_object_nested_within_object(self):
    method test_for_json_encodes_list (line 61) | def test_for_json_encodes_list(self):
    method test_for_json_encodes_list_within_object (line 66) | def test_for_json_encodes_list_within_object(self):
    method test_for_json_encodes_dict_subclass (line 71) | def test_for_json_encodes_dict_subclass(self):
    method test_for_json_encodes_list_subclass (line 76) | def test_for_json_encodes_list_subclass(self):
    method test_for_json_ignored_if_not_true_with_dict_subclass (line 81) | def test_for_json_ignored_if_not_true_with_dict_subclass(self):
    method test_for_json_ignored_if_not_true_with_list_subclass (line 88) | def test_for_json_ignored_if_not_true_with_list_subclass(self):
    method test_raises_typeerror_if_for_json_not_true_with_object (line 95) | def test_raises_typeerror_if_for_json_not_true_with_object(self):

FILE: simplejson/tests/test_indent.py
  class TestIndent (line 7) | class TestIndent(TestCase):
    method test_indent (line 8) | def test_indent(self):
    method test_indent0 (line 56) | def test_indent0(self):
    method test_separators (line 71) | def test_separators(self):

FILE: simplejson/tests/test_item_sort_key.py
  class TestItemSortKey (line 6) | class TestItemSortKey(TestCase):
    method test_simple_first (line 7) | def test_simple_first(self):
    method test_case (line 13) | def test_case(self):
    method test_item_sort_key_value (line 22) | def test_item_sort_key_value(self):

FILE: simplejson/tests/test_iterable.py
  function iter_dumps (line 6) | def iter_dumps(obj, **kw):
  function sio_dump (line 9) | def sio_dump(obj, **kw):
  class TestIterable (line 14) | class TestIterable(unittest.TestCase):
    method test_iterable (line 15) | def test_iterable(self):

FILE: simplejson/tests/test_namedtuple.py
  class Value (line 9) | class Value(tuple):
    method __new__ (line 10) | def __new__(cls, *args):
    method _asdict (line 13) | def _asdict(self):
  class Point (line 15) | class Point(tuple):
    method __new__ (line 16) | def __new__(cls, *args):
    method _asdict (line 19) | def _asdict(self):
  class DuckValue (line 25) | class DuckValue(object):
    method __init__ (line 26) | def __init__(self, *args):
    method _asdict (line 29) | def _asdict(self):
  class DuckPoint (line 32) | class DuckPoint(object):
    method __init__ (line 33) | def __init__(self, *args):
    method _asdict (line 36) | def _asdict(self):
  class DeadDuck (line 39) | class DeadDuck(object):
  class DeadDict (line 42) | class DeadDict(dict):
  class TestNamedTuple (line 51) | class TestNamedTuple(unittest.TestCase):
    method test_namedtuple_dumps (line 52) | def test_namedtuple_dumps(self):
    method test_namedtuple_dumps_false (line 65) | def test_namedtuple_dumps_false(self):
    method test_namedtuple_dump (line 74) | def test_namedtuple_dump(self):
    method test_namedtuple_dump_false (line 95) | def test_namedtuple_dump_false(self):
    method test_asdict_not_callable_dump (line 106) | def test_asdict_not_callable_dump(self):
    method test_asdict_not_callable_dumps (line 116) | def test_asdict_not_callable_dumps(self):

FILE: simplejson/tests/test_pass1.py
  class TestPass1 (line 66) | class TestPass1(TestCase):
    method test_parse (line 67) | def test_parse(self):

FILE: simplejson/tests/test_pass2.py
  class TestPass2 (line 9) | class TestPass2(TestCase):
    method test_parse (line 10) | def test_parse(self):

FILE: simplejson/tests/test_pass3.py
  class TestPass3 (line 15) | class TestPass3(TestCase):
    method test_parse (line 16) | def test_parse(self):

FILE: simplejson/tests/test_raw_json.py
  class TestRawJson (line 24) | class TestRawJson(unittest.TestCase):
    method test_normal_str (line 26) | def test_normal_str(self):
    method test_raw_json_str (line 29) | def test_raw_json_str(self):
    method test_list (line 33) | def test_list(self):
    method test_direct (line 41) | def test_direct(self):

FILE: simplejson/tests/test_recursion.py
  class JSONTestObject (line 5) | class JSONTestObject:
  class RecursiveJSONEncoder (line 9) | class RecursiveJSONEncoder(json.JSONEncoder):
    method default (line 11) | def default(self, o):
  class TestRecursion (line 20) | class TestRecursion(TestCase):
    method test_listrecursion (line 21) | def test_listrecursion(self):
    method test_dictrecursion (line 44) | def test_dictrecursion(self):
    method test_defaultrecursion (line 58) | def test_defaultrecursion(self):

FILE: simplejson/tests/test_scanstring.py
  class TestScanString (line 8) | class TestScanString(TestCase):
    method test_py_scanstring (line 17) | def test_py_scanstring(self):
    method test_c_scanstring (line 20) | def test_c_scanstring(self):
    method _test_scanstring (line 27) | def _test_scanstring(self, scanstring):
    method test_issue3623 (line 135) | def test_issue3623(self):
    method test_overflow (line 141) | def test_overflow(self):
    method test_surrogates (line 148) | def test_surrogates(self):

FILE: simplejson/tests/test_separators.py
  class TestSeparators (line 7) | class TestSeparators(TestCase):
    method test_separators (line 8) | def test_separators(self):

FILE: simplejson/tests/test_speedups.py
  function has_speedups (line 12) | def has_speedups():
  function skip_if_speedups_missing (line 16) | def skip_if_speedups_missing(func):
  class BadBool (line 29) | class BadBool:
    method __bool__ (line 30) | def __bool__(self):
  class TestDecode (line 35) | class TestDecode(TestCase):
    method test_make_scanner (line 37) | def test_make_scanner(self):
    method test_bad_bool_args (line 41) | def test_bad_bool_args(self):
  class TestEncode (line 50) | class TestEncode(TestCase):
    method test_make_encoder (line 52) | def test_make_encoder(self):
    method test_bad_str_encoder (line 63) | def test_bad_str_encoder(self):
    method test_bad_bool_args (line 87) | def test_bad_bool_args(self):
    method test_int_as_string_bitcount_overflow (line 104) | def test_int_as_string_bitcount_overflow(self):
    method test_bad_encoding (line 112) | def test_bad_encoding(self):

FILE: simplejson/tests/test_str_subclass.py
  class WonkyTextSubclass (line 7) | class WonkyTextSubclass(text_type):
    method __getslice__ (line 8) | def __getslice__(self, start, end):
  class TestStrSubclass (line 11) | class TestStrSubclass(TestCase):
    method test_dump_load (line 12) | def test_dump_load(self):

FILE: simplejson/tests/test_subclass.py
  class AlternateInt (line 6) | class AlternateInt(int):
    method __repr__ (line 7) | def __repr__(self):
  class AlternateFloat (line 12) | class AlternateFloat(float):
    method __repr__ (line 13) | def __repr__(self):
  class TestSubclass (line 23) | class TestSubclass(TestCase):
    method test_int (line 24) | def test_int(self):
    method test_float (line 29) | def test_float(self):

FILE: simplejson/tests/test_tool.py
  function strip_python_stderr (line 18) | def strip_python_stderr(stderr):
  function open_temp_file (line 24) | def open_temp_file():
  class TestTool (line 33) | class TestTool(unittest.TestCase):
    method runTool (line 64) | def runTool(self, args=None, data=None):
    method test_stdin_stdout (line 77) | def test_stdin_stdout(self):
    method test_infile_stdout (line 82) | def test_infile_stdout(self):
    method test_infile_outfile (line 93) | def test_infile_outfile(self):

FILE: simplejson/tests/test_tuple.py
  class TestTuples (line 6) | class TestTuples(unittest.TestCase):
    method test_tuple_array_dumps (line 7) | def test_tuple_array_dumps(self):
    method test_tuple_array_dump (line 23) | def test_tuple_array_dump(self):

FILE: simplejson/tests/test_unicode.py
  class TestUnicode (line 8) | class TestUnicode(TestCase):
    method test_encoding1 (line 9) | def test_encoding1(self):
    method test_encoding2 (line 17) | def test_encoding2(self):
    method test_encoding3 (line 24) | def test_encoding3(self):
    method test_encoding4 (line 29) | def test_encoding4(self):
    method test_encoding5 (line 34) | def test_encoding5(self):
    method test_encoding6 (line 39) | def test_encoding6(self):
    method test_big_unicode_encode (line 44) | def test_big_unicode_encode(self):
    method test_big_unicode_decode (line 49) | def test_big_unicode_decode(self):
    method test_unicode_decode (line 54) | def test_unicode_decode(self):
    method test_object_pairs_hook_with_unicode (line 61) | def test_object_pairs_hook_with_unicode(self):
    method test_default_encoding (line 77) | def test_default_encoding(self):
    method test_unicode_preservation (line 81) | def test_unicode_preservation(self):
    method test_ensure_ascii_false_returns_unicode (line 86) | def test_ensure_ascii_false_returns_unicode(self):
    method test_ensure_ascii_false_bytestring_encoding (line 93) | def test_ensure_ascii_false_bytestring_encoding(self):
    method test_ensure_ascii_linebreak_encoding (line 104) | def test_ensure_ascii_linebreak_encoding(self):
    method test_invalid_escape_sequences (line 115) | def test_invalid_escape_sequences(self):
    method test_ensure_ascii_still_works (line 138) | def test_ensure_ascii_still_works(self):
    method test_strip_bom (line 149) | def test_strip_bom(self):

FILE: simplejson/tool.py
  function main (line 17) | def main():

FILE: train.py
  function main (line 68) | def main():
  function train (line 246) | def train(model_prime, model, fc, memory, ppo, optimizer, train_loader, ...
  function validate (line 391) | def validate(model_prime, model, fc, memory, ppo, _, val_loader, criterion,

FILE: utils.py
  function mkdir_p (line 11) | def mkdir_p(path):
  class AverageMeter (line 22) | class AverageMeter(object):
    method __init__ (line 25) | def __init__(self):
    method reset (line 28) | def reset(self):
    method update (line 34) | def update(self, val, n=1):
  function accuracy (line 41) | def accuracy(output, target, topk=(1,)):
  function get_prime (line 53) | def get_prime(images, patch_size, interpolation='bicubic'):
  function get_patch (line 59) | def get_patch(images, action_sequence, patch_size):
  function adjust_learning_rate (line 76) | def adjust_learning_rate(optimizer, train_configuration, epoch, training...
  function save_checkpoint (line 100) | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='c...

FILE: yacs/config.py
  class CfgNode (line 63) | class CfgNode(dict):
    method __init__ (line 74) | def __init__(self, init_dict=None, key_list=None, new_allowed=False):
    method _create_config_tree_from_dict (line 112) | def _create_config_tree_from_dict(cls, dic, key_list):
    method __getattr__ (line 137) | def __getattr__(self, name):
    method __setattr__ (line 143) | def __setattr__(self, name, value):
    method __str__ (line 164) | def __str__(self):
    method __repr__ (line 185) | def __repr__(self):
    method dump (line 188) | def dump(self, **kwargs):
    method merge_from_file (line 209) | def merge_from_file(self, cfg_filename):
    method merge_from_other_cfg (line 215) | def merge_from_other_cfg(self, cfg_other):
    method merge_from_list (line 219) | def merge_from_list(self, cfg_list):
    method freeze (line 248) | def freeze(self):
    method defrost (line 252) | def defrost(self):
    method is_frozen (line 256) | def is_frozen(self):
    method _immutable (line 260) | def _immutable(self, is_immutable):
    method clone (line 273) | def clone(self):
    method register_deprecated_key (line 277) | def register_deprecated_key(self, key):
    method register_renamed_key (line 287) | def register_renamed_key(self, old_name, new_name, message=None):
    method key_is_deprecated (line 301) | def key_is_deprecated(self, full_key):
    method key_is_renamed (line 308) | def key_is_renamed(self, full_key):
    method raise_key_rename_error (line 312) | def raise_key_rename_error(self, full_key):
    method is_new_allowed (line 325) | def is_new_allowed(self):
    method load_cfg (line 329) | def load_cfg(cls, cfg_file_obj_or_str):
    method _load_cfg_from_file (line 354) | def _load_cfg_from_file(cls, file_obj):
    method _load_cfg_from_yaml_str (line 368) | def _load_cfg_from_yaml_str(cls, str_obj):
    method _load_cfg_py_source (line 374) | def _load_cfg_py_source(cls, filename):
    method _decode_cfg_value (line 391) | def _decode_cfg_value(cls, value):
  function _valid_type (line 434) | def _valid_type(value, allow_cfg_node=False):
  function _merge_a_into_b (line 440) | def _merge_a_into_b(a, b, root, key_list):
  function _check_and_coerce_cfg_value_type (line 480) | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
  function _assert_with_logging (line 522) | def _assert_with_logging(cond, msg):
  function _load_module_from_file (line 528) | def _load_module_from_file(name, filename):

FILE: yacs/tests.py
  class SubCN (line 15) | class SubCN(CN):
  function get_cfg (line 19) | def get_cfg(cls=CN):
  class TestCfgNode (line 59) | class TestCfgNode(unittest.TestCase):
    method test_immutability (line 60) | def test_immutability(self):
  class TestCfg (line 88) | class TestCfg(unittest.TestCase):
    method test_copy_cfg (line 89) | def test_copy_cfg(self):
    method test_merge_cfg_from_cfg (line 96) | def test_merge_cfg_from_cfg(self):
    method test_merge_cfg_from_file (line 152) | def test_merge_cfg_from_file(self):
    method test_merge_cfg_from_list (line 163) | def test_merge_cfg_from_list(self):
    method test_deprecated_key_from_list (line 177) | def test_deprecated_key_from_list(self):
    method test_nonexistant_key_from_list (line 192) | def test_nonexistant_key_from_list(self):
    method test_load_cfg_invalid_type (line 198) | def test_load_cfg_invalid_type(self):
    method test_deprecated_key_from_file (line 204) | def test_deprecated_key_from_file(self):
    method test_renamed_key_from_list (line 219) | def test_renamed_key_from_list(self):
    method test_renamed_key_from_file (line 227) | def test_renamed_key_from_file(self):
    method test_load_cfg_from_file (line 241) | def test_load_cfg_from_file(self):
    method test_load_from_python_file (line 249) | def test_load_from_python_file(self):
    method test_invalid_type (line 259) | def test_invalid_type(self):
    method test__str__ (line 264) | def test__str__(self):
    method test_new_allowed (line 289) | def test_new_allowed(self):
    method test_new_allowed_bad (line 296) | def test_new_allowed_bad(self):
  class TestCfgNodeSubclass (line 302) | class TestCfgNodeSubclass(unittest.TestCase):
    method test_merge_cfg_from_file (line 303) | def test_merge_cfg_from_file(self):
    method test_merge_cfg_from_list (line 314) | def test_merge_cfg_from_list(self):
    method test_merge_cfg_from_cfg (line 328) | def test_merge_cfg_from_cfg(self):
Condensed preview — 103 files, each showing path, character count, and a content snippet (656K chars total in the full structured content).
[
  {
    "path": ".gitignore",
    "chars": 36,
    "preview": "\nmodels/.DS_Store\nfigures/.DS_Store\n"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 9117,
    "preview": "# Glance-and-Focus Networks (PyTorch)\n\nThis repo contains the official code and pre-trained models for the glance and fo"
  },
  {
    "path": "configs.py",
    "chars": 6009,
    "preview": "from PIL import Image\r\n\r\nmodel_configurations = {\r\n    'resnet50': {\r\n        'feature_num': 2048,\r\n        'feature_map"
  },
  {
    "path": "inference.py",
    "chars": 12188,
    "preview": "\r\nimport torchvision.transforms as transforms\r\nimport torchvision.datasets as datasets\r\n\r\nimport torch.multiprocessing\r\n"
  },
  {
    "path": "models/__init__.py",
    "chars": 252,
    "preview": "from .gen_efficientnet import *\nfrom .mobilenetv3 import *\nfrom .model_factory import create_model\nfrom .config import i"
  },
  {
    "path": "models/activations/__init__.py",
    "chars": 2878,
    "preview": "from .config import *\nfrom .activations_autofn import *\nfrom .activations_jit import *\nfrom .activations import *\n\n\n_ACT"
  },
  {
    "path": "models/activations/activations.py",
    "chars": 2365,
    "preview": "from torch import nn as nn\nfrom torch.nn import functional as F\n\n\ndef swish(x, inplace: bool = False):\n    \"\"\"Swish - De"
  },
  {
    "path": "models/activations/activations_autofn.py",
    "chars": 1962,
    "preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\n__all__ = ['swish_auto', 'SwishAuto', 'mi"
  },
  {
    "path": "models/activations/activations_jit.py",
    "chars": 2737,
    "preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\n\n__all__ = ['swish_jit', 'SwishJit', 'mish"
  },
  {
    "path": "models/activations/config.py",
    "chars": 527,
    "preview": "\"\"\" Global Config and Constants\n\"\"\"\n\n__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']\n\n#"
  },
  {
    "path": "models/config.py",
    "chars": 527,
    "preview": "\"\"\" Global Config and Constants\n\"\"\"\n\n__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable']\n\n#"
  },
  {
    "path": "models/conv2d_layers.py",
    "chars": 11934,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch._six import container_abcs\n\nfrom itertools"
  },
  {
    "path": "models/densenet.py",
    "chars": 10014,
    "preview": "\nimport re\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom collections import OrderedDict\n\n__all"
  },
  {
    "path": "models/efficientnet_builder.py",
    "chars": 26128,
    "preview": "import re\nfrom copy import deepcopy\n\nfrom .conv2d_layers import *\nfrom .activations import *\n\n\n# Defaults used for Googl"
  },
  {
    "path": "models/gen_efficientnet.py",
    "chars": 56097,
    "preview": "\"\"\" Generic Efficient Networks\n\nA generic MobileNet class with building blocks to support a variety of models:\n\n* Effici"
  },
  {
    "path": "models/helpers.py",
    "chars": 2726,
    "preview": "import torch\nimport os\nfrom collections import OrderedDict\ntry:\n    from torch.hub import load_state_dict_from_url\nexcep"
  },
  {
    "path": "models/mobilenetv3.py",
    "chars": 14981,
    "preview": "\"\"\" MobileNet-V3\n\nA PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.\n\nPaper: Searching for M"
  },
  {
    "path": "models/model_factory.py",
    "chars": 655,
    "preview": "from .mobilenetv3 import *\nfrom .gen_efficientnet import *\nfrom .helpers import load_checkpoint\n\n\ndef create_model(\n    "
  },
  {
    "path": "models/resnet.py",
    "chars": 8977,
    "preview": "import torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\n\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50'"
  },
  {
    "path": "models/version.py",
    "chars": 22,
    "preview": "__version__ = '0.9.8'\n"
  },
  {
    "path": "network.py",
    "chars": 8809,
    "preview": "import torch\r\nimport torchvision\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport math\r\n\r\n\r\nclass Memory:"
  },
  {
    "path": "pycls/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "pycls/cfgs/RegNetY-1.6GF_dds_8gpu.yaml",
    "chars": 377,
    "preview": "MODEL:\n  TYPE: regnet\n  NUM_CLASSES: 1000\nREGNET:\n  SE_ON: True\n  DEPTH: 27\n  W0: 48\n  WA: 20.71\n  WM: 2.65\n  GROUP_W: 2"
  },
  {
    "path": "pycls/cfgs/RegNetY-600MF_dds_8gpu.yaml",
    "chars": 377,
    "preview": "MODEL:\n  TYPE: regnet\n  NUM_CLASSES: 1000\nREGNET:\n  SE_ON: True\n  DEPTH: 15\n  W0: 48\n  WA: 32.54\n  WM: 2.32\n  GROUP_W: 1"
  },
  {
    "path": "pycls/cfgs/RegNetY-800MF_dds_8gpu.yaml",
    "chars": 376,
    "preview": "MODEL:\n  TYPE: regnet\n  NUM_CLASSES: 1000\nREGNET:\n  SE_ON: True\n  DEPTH: 14\n  W0: 56\n  WA: 38.84\n  WM: 2.4\n  GROUP_W: 16"
  },
  {
    "path": "pycls/core/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "pycls/core/config.py",
    "chars": 10612,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/core/losses.py",
    "chars": 714,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/core/model_builder.py",
    "chars": 1640,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/core/old_config.py",
    "chars": 10612,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/core/optimizer.py",
    "chars": 2485,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/datasets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "pycls/datasets/cifar10.py",
    "chars": 2777,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/datasets/imagenet.py",
    "chars": 3671,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/datasets/loader.py",
    "chars": 2570,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/datasets/paths.py",
    "chars": 822,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/datasets/transforms.py",
    "chars": 3486,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "pycls/models/anynet.py",
    "chars": 13065,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/models/effnet.py",
    "chars": 7440,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/models/regnet.py",
    "chars": 3033,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/models/resnet.py",
    "chars": 9767,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "pycls/utils/benchmark.py",
    "chars": 2776,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/checkpoint.py",
    "chars": 3046,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/distributed.py",
    "chars": 1946,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/error_handler.py",
    "chars": 1872,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/io.py",
    "chars": 2675,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/logging.py",
    "chars": 3083,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/lr_policy.py",
    "chars": 1507,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/meters.py",
    "chars": 7866,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/metrics.py",
    "chars": 3856,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/multiprocessing.py",
    "chars": 1600,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/net.py",
    "chars": 3111,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/plotting.py",
    "chars": 4656,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "pycls/utils/timer.py",
    "chars": 875,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MI"
  },
  {
    "path": "simplejson/__init__.py",
    "chars": 23745,
    "preview": "r\"\"\"JSON (JavaScript Object Notation) <http://json.org> is a subset of\nJavaScript syntax (ECMA-262 3rd edition) used as "
  },
  {
    "path": "simplejson/_speedups.c",
    "chars": 108029,
    "preview": "/* -*- mode: C; c-file-style: \"python\"; c-basic-offset: 4 -*- */\n#include \"Python.h\"\n#include \"structmember.h\"\n\n#if PY_M"
  },
  {
    "path": "simplejson/compat.py",
    "chars": 815,
    "preview": "\"\"\"Python 3 compatibility shims\n\"\"\"\nimport sys\nif sys.version_info[0] < 3:\n    PY3 = False\n    def b(s):\n        return "
  },
  {
    "path": "simplejson/decoder.py",
    "chars": 14519,
    "preview": "\"\"\"Implementation of JSONDecoder\n\"\"\"\nfrom __future__ import absolute_import\nimport re\nimport sys\nimport struct\nfrom .com"
  },
  {
    "path": "simplejson/encoder.py",
    "chars": 28570,
    "preview": "\"\"\"Implementation of JSONEncoder\n\"\"\"\nfrom __future__ import absolute_import\nimport re\nfrom operator import itemgetter\n# "
  },
  {
    "path": "simplejson/errors.py",
    "chars": 1779,
    "preview": "\"\"\"Error classes used by simplejson\n\"\"\"\n__all__ = ['JSONDecodeError']\n\n\ndef linecol(doc, pos):\n    lineno = doc.count('\\"
  },
  {
    "path": "simplejson/ordered_dict.py",
    "chars": 2945,
    "preview": "\"\"\"Drop-in replacement for collections.OrderedDict by Raymond Hettinger\n\nhttp://code.activestate.com/recipes/576693/\n\n\"\""
  },
  {
    "path": "simplejson/raw_json.py",
    "chars": 217,
    "preview": "\"\"\"Implementation of RawJSON\n\"\"\"\n\nclass RawJSON(object):\n    \"\"\"Wrap an encoded JSON document for direct embedding in th"
  },
  {
    "path": "simplejson/scanner.py",
    "chars": 2971,
    "preview": "\"\"\"JSON token scanner\n\"\"\"\nimport re\nfrom .errors import JSONDecodeError\ndef _import_c_make_scanner():\n    try:\n        f"
  },
  {
    "path": "simplejson/tests/__init__.py",
    "chars": 2148,
    "preview": "from __future__ import absolute_import\nimport unittest\nimport sys\nimport os\n\n\nclass NoExtensionTestSuite(unittest.TestSu"
  },
  {
    "path": "simplejson/tests/test_bigint_as_string.py",
    "chars": 2238,
    "preview": "from unittest import TestCase\n\nimport simplejson as json\n\n\nclass TestBigintAsString(TestCase):\n    # Python 2.5, at leas"
  },
  {
    "path": "simplejson/tests/test_bitsize_int_as_string.py",
    "chars": 2297,
    "preview": "from unittest import TestCase\n\nimport simplejson as json\n\n\nclass TestBitSizeIntAsString(TestCase):\n    # Python 2.5, at "
  },
  {
    "path": "simplejson/tests/test_check_circular.py",
    "chars": 917,
    "preview": "from unittest import TestCase\nimport simplejson as json\n\ndef default_iterable(obj):\n    return list(obj)\n\nclass TestChec"
  },
  {
    "path": "simplejson/tests/test_decimal.py",
    "chars": 2544,
    "preview": "import decimal\nfrom decimal import Decimal\nfrom unittest import TestCase\nfrom simplejson.compat import StringIO, reload_"
  },
  {
    "path": "simplejson/tests/test_decode.py",
    "chars": 4835,
    "preview": "from __future__ import absolute_import\nimport decimal\nfrom unittest import TestCase\n\nimport simplejson as json\nfrom simp"
  },
  {
    "path": "simplejson/tests/test_default.py",
    "chars": 221,
    "preview": "from unittest import TestCase\n\nimport simplejson as json\n\nclass TestDefault(TestCase):\n    def test_default(self):\n     "
  },
  {
    "path": "simplejson/tests/test_dump.py",
    "chars": 10356,
    "preview": "from unittest import TestCase\nfrom simplejson.compat import StringIO, long_type, b, binary_type, text_type, PY3\nimport s"
  },
  {
    "path": "simplejson/tests/test_encode_basestring_ascii.py",
    "chars": 2337,
    "preview": "from unittest import TestCase\n\nimport simplejson.encoder\nfrom simplejson.compat import b\n\nCASES = [\n    (u'/\\\\\"\\ucafe\\ub"
  },
  {
    "path": "simplejson/tests/test_encode_for_html.py",
    "chars": 1515,
    "preview": "import unittest\n\nimport simplejson as json\n\nclass TestEncodeForHTML(unittest.TestCase):\n\n    def setUp(self):\n        se"
  },
  {
    "path": "simplejson/tests/test_errors.py",
    "chars": 2081,
    "preview": "import sys, pickle\nfrom unittest import TestCase\n\nimport simplejson as json\nfrom simplejson.compat import text_type, b\n\n"
  },
  {
    "path": "simplejson/tests/test_fail.py",
    "chars": 6346,
    "preview": "import sys\nfrom unittest import TestCase\n\nimport simplejson as json\n\n# 2007-10-05\nJSONDOCS = [\n    # http://json.org/JSO"
  },
  {
    "path": "simplejson/tests/test_float.py",
    "chars": 1430,
    "preview": "import math\nfrom unittest import TestCase\nfrom simplejson.compat import long_type, text_type\nimport simplejson as json\nf"
  },
  {
    "path": "simplejson/tests/test_for_json.py",
    "chars": 2767,
    "preview": "import unittest\nimport simplejson as json\n\n\nclass ForJson(object):\n    def for_json(self):\n        return {'for_json': 1"
  },
  {
    "path": "simplejson/tests/test_indent.py",
    "chars": 2568,
    "preview": "from unittest import TestCase\nimport textwrap\n\nimport simplejson as json\nfrom simplejson.compat import StringIO\n\nclass T"
  },
  {
    "path": "simplejson/tests/test_item_sort_key.py",
    "chars": 1376,
    "preview": "from unittest import TestCase\n\nimport simplejson as json\nfrom operator import itemgetter\n\nclass TestItemSortKey(TestCase"
  },
  {
    "path": "simplejson/tests/test_iterable.py",
    "chars": 1390,
    "preview": "import unittest\nfrom simplejson.compat import StringIO\n\nimport simplejson as json\n\ndef iter_dumps(obj, **kw):\n    return"
  },
  {
    "path": "simplejson/tests/test_namedtuple.py",
    "chars": 4004,
    "preview": "from __future__ import absolute_import\nimport unittest\nimport simplejson as json\nfrom simplejson.compat import StringIO\n"
  },
  {
    "path": "simplejson/tests/test_pass1.py",
    "chars": 1746,
    "preview": "from unittest import TestCase\n\nimport simplejson as json\n\n# from http://json.org/JSON_checker/test/pass1.json\nJSON = r''"
  },
  {
    "path": "simplejson/tests/test_pass2.py",
    "chars": 386,
    "preview": "from unittest import TestCase\nimport simplejson as json\n\n# from http://json.org/JSON_checker/test/pass2.json\nJSON = r'''"
  },
  {
    "path": "simplejson/tests/test_pass3.py",
    "chars": 482,
    "preview": "from unittest import TestCase\n\nimport simplejson as json\n\n# from http://json.org/JSON_checker/test/pass3.json\nJSON = r''"
  },
  {
    "path": "simplejson/tests/test_raw_json.py",
    "chars": 1062,
    "preview": "import unittest\nimport simplejson as json\n\ndct1 = {\n    'key1': 'value1'\n}\n\ndct2 = {\n    'key2': 'value2',\n    'd1': dct"
  },
  {
    "path": "simplejson/tests/test_recursion.py",
    "chars": 1679,
    "preview": "from unittest import TestCase\n\nimport simplejson as json\n\nclass JSONTestObject:\n    pass\n\n\nclass RecursiveJSONEncoder(js"
  },
  {
    "path": "simplejson/tests/test_scanstring.py",
    "chars": 7398,
    "preview": "import sys\nfrom unittest import TestCase\n\nimport simplejson as json\nimport simplejson.decoder\nfrom simplejson.compat imp"
  },
  {
    "path": "simplejson/tests/test_separators.py",
    "chars": 942,
    "preview": "import textwrap\nfrom unittest import TestCase\n\nimport simplejson as json\n\n\nclass TestSeparators(TestCase):\n    def test_"
  },
  {
    "path": "simplejson/tests/test_speedups.py",
    "chars": 4144,
    "preview": "from __future__ import with_statement\n\nimport sys\nimport unittest\nfrom unittest import TestCase\n\nimport simplejson\nfrom "
  },
  {
    "path": "simplejson/tests/test_str_subclass.py",
    "chars": 740,
    "preview": "from unittest import TestCase\n\nimport simplejson\nfrom simplejson.compat import text_type\n\n# Tests for issue demonstrated"
  },
  {
    "path": "simplejson/tests/test_subclass.py",
    "chars": 1124,
    "preview": "from unittest import TestCase\nimport simplejson as json\n\nfrom decimal import Decimal\n\nclass AlternateInt(int):\n    def _"
  },
  {
    "path": "simplejson/tests/test_tool.py",
    "chars": 3304,
    "preview": "from __future__ import with_statement\nimport os\nimport sys\nimport textwrap\nimport unittest\nimport subprocess\nimport temp"
  },
  {
    "path": "simplejson/tests/test_tuple.py",
    "chars": 1831,
    "preview": "import unittest\n\nfrom simplejson.compat import StringIO\nimport simplejson as json\n\nclass TestTuples(unittest.TestCase):\n"
  },
  {
    "path": "simplejson/tests/test_unicode.py",
    "chars": 7056,
    "preview": "import sys\nimport codecs\nfrom unittest import TestCase\n\nimport simplejson as json\nfrom simplejson.compat import unichr, "
  },
  {
    "path": "simplejson/tool.py",
    "chars": 1136,
    "preview": "r\"\"\"Command-line tool to validate and pretty-print JSON\n\nUsage::\n\n    $ echo '{\"json\":\"obj\"}' | python -m simplejson.too"
  },
  {
    "path": "train.py",
    "chars": 20851,
    "preview": "import time\r\n\r\nimport torchvision.transforms as transforms\r\nimport torchvision.datasets as datasets\r\n\r\nimport torch.mult"
  },
  {
    "path": "utils.py",
    "chars": 3315,
    "preview": "import os\r\nimport errno\r\nimport math\r\nimport shutil\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional a"
  },
  {
    "path": "yacs/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "yacs/config.py",
    "chars": 19682,
    "preview": "# Copyright (c) 2018-present, Facebook, Inc.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you m"
  },
  {
    "path": "yacs/tests.py",
    "chars": 10195,
    "preview": "import logging\nimport tempfile\nimport unittest\n\nimport yacs.config\nfrom yacs.config import CfgNode as CN\n\ntry:\n    _igno"
  }
]

About this extraction

This page contains the full source code of the blackfeather-wang/GFNet-Pytorch GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 103 files (610.4 KB), approximately 162.8k tokens, and a symbol index with 986 extracted functions, classes, methods, constants, and types.
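The file manifest above is a JSON array of records, each with a `path`, a `chars` count, and a truncated `preview`. As a minimal sketch of consuming it, the snippet below loads a small hand-copied subset of those records and ranks files by size; the `MANIFEST` constant and `largest_files` helper are illustrative names, not part of the extraction format itself.

```python
import json

# A few records copied from the manifest above (previews omitted);
# the full manifest contains 103 entries.
MANIFEST = json.loads("""
[
  {"path": "simplejson/_speedups.c", "chars": 108029},
  {"path": "train.py", "chars": 20851},
  {"path": "pycls/utils/timer.py", "chars": 875},
  {"path": "pycls/models/__init__.py", "chars": 0}
]
""")

def largest_files(records, n=3):
    """Return the n largest files by character count."""
    return sorted(records, key=lambda r: r["chars"], reverse=True)[:n]

for rec in largest_files(MANIFEST):
    print(f"{rec['path']}: {rec['chars']} chars")
```

Sorting on the `chars` field is a quick way to spot which files dominate the token budget before feeding the dump to an LLM (here, the C extension `simplejson/_speedups.c` alone accounts for roughly a sixth of the total size).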

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
